mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
support text2speech ultrachat
This commit is contained in:
parent
49256fa917
commit
4c0396f8f2
@ -3,7 +3,7 @@
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
|
||||
export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
|
||||
# export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
|
||||
set -eou pipefail
|
||||
|
||||
stage=$1
|
||||
@ -123,7 +123,9 @@ fi
|
||||
|
||||
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
log "stage 19: Training TTS Model"
|
||||
exp_dir=./qwen_omni/exp_tts
|
||||
exp_dir=./qwen_omni/exp_tts_ultra_chat_voice_assistant
|
||||
exp_dir=./qwen_omni/exp_tts_emilia_en_tts_only_template
|
||||
exp_dir=./qwen_omni/exp_tts_emilia_en_tts_three_concat
|
||||
pretrained_dir=./qwen_omni/exp_speech2text
|
||||
ngpu=4
|
||||
|
||||
@ -141,17 +143,16 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
train_cmd_args="--batch-size 64 \
|
||||
# --dataset ultra_chat_voice_assistant
|
||||
train_cmd_args="--batch-size 30 \
|
||||
--exp-dir $exp_dir \
|
||||
--last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--enable-speech-input False \
|
||||
--deepspeed \
|
||||
--dataset /lustre/fsw/general_sa/yuekaiz/s2s/emilia_en \
|
||||
--dataset /lustre/fsw/general_sa/yuekaiz/s2s/VoxBox/manifests_emilia_en \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--num-epochs 2 \
|
||||
--num-epochs 3 \
|
||||
--use-lora False --unfreeze-llm False --enable-speech-output True"
|
||||
|
||||
if [ "$latest_checkpoint_step" -ge 0 ]; then
|
||||
@ -168,66 +169,66 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||
log "stage 20: Training TTS Model"
|
||||
echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
|
||||
if [ ! -L "/workspace/slam" ]; then
|
||||
cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
|
||||
fi
|
||||
exp_dir=./qwen_omni/exp_test
|
||||
ngpu=4
|
||||
# if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||
# log "stage 20: Training TTS Model"
|
||||
# echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
|
||||
# if [ ! -L "/workspace/slam" ]; then
|
||||
# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
|
||||
# fi
|
||||
# exp_dir=./qwen_omni/exp_test
|
||||
# ngpu=4
|
||||
|
||||
latest_checkpoint_step=-1
|
||||
# Check if exp_dir exists and is a directory
|
||||
if [ -d "$exp_dir" ]; then
|
||||
# List directories matching checkpoint-* and find the one with the largest step number
|
||||
for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
|
||||
checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
|
||||
# Extract step number using parameter expansion
|
||||
current_step=${checkpoint_name#checkpoint-}
|
||||
# Ensure current_step is a number
|
||||
if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
|
||||
latest_checkpoint_step=$current_step
|
||||
fi
|
||||
done
|
||||
fi
|
||||
# latest_checkpoint_step=-1
|
||||
# # Check if exp_dir exists and is a directory
|
||||
# if [ -d "$exp_dir" ]; then
|
||||
# # List directories matching checkpoint-* and find the one with the largest step number
|
||||
# for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
|
||||
# checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
|
||||
# # Extract step number using parameter expansion
|
||||
# current_step=${checkpoint_name#checkpoint-}
|
||||
# # Ensure current_step is a number
|
||||
# if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
|
||||
# latest_checkpoint_step=$current_step
|
||||
# fi
|
||||
# done
|
||||
# fi
|
||||
|
||||
train_cmd_args="--max-duration 150 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset vocalnet_ultrachat_voiceassistant \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True --on-the-fly-feats True \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
|
||||
# train_cmd_args="--max-duration 150 \
|
||||
# --enable-musan False \
|
||||
# --exp-dir $exp_dir \
|
||||
# --speech-encoder-path-or-name models/large-v2.pt \
|
||||
# --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
# --dataset vocalnet_ultrachat_voiceassistant \
|
||||
# --manifest-dir data/fbank \
|
||||
# --deepspeed \
|
||||
# --deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
# --use-flash-attn True --on-the-fly-feats True \
|
||||
# --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
|
||||
|
||||
if [ "$latest_checkpoint_step" -ge 0 ]; then
|
||||
log "Continuing training from checkpoint-$latest_checkpoint_step"
|
||||
step=$latest_checkpoint_step
|
||||
train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
|
||||
else
|
||||
log "Starting training from scratch as no checkpoint was found in $exp_dir"
|
||||
# No pretrained model or sampler state dict needed for the first run
|
||||
fi
|
||||
# if [ "$latest_checkpoint_step" -ge 0 ]; then
|
||||
# log "Continuing training from checkpoint-$latest_checkpoint_step"
|
||||
# step=$latest_checkpoint_step
|
||||
# train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
|
||||
# else
|
||||
# log "Starting training from scratch as no checkpoint was found in $exp_dir"
|
||||
# # No pretrained model or sampler state dict needed for the first run
|
||||
# fi
|
||||
|
||||
torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
|
||||
$train_cmd_args
|
||||
fi
|
||||
# torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
|
||||
# $train_cmd_args
|
||||
# fi
|
||||
|
||||
|
||||
if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
|
||||
log "stage 21: TTS Decoding Test Set"
|
||||
exp_dir=./qwen_omni/exp_tts
|
||||
torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
|
||||
--use-flash-attn True \
|
||||
--enable-speech-output True \
|
||||
--token2wav-path /workspace/CosyVoice2-0.5B \
|
||||
--use-lora True
|
||||
fi
|
||||
# if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
|
||||
# log "stage 21: TTS Decoding Test Set"
|
||||
# exp_dir=./qwen_omni/exp_tts
|
||||
# torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
|
||||
# --exp-dir $exp_dir \
|
||||
# --speech-encoder-path-or-name models/large-v2.pt \
|
||||
# --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
# --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
|
||||
# --use-flash-attn True \
|
||||
# --enable-speech-output True \
|
||||
# --token2wav-path /workspace/CosyVoice2-0.5B \
|
||||
# --use-lora True
|
||||
# fi
|
||||
|
@ -437,7 +437,8 @@ class SPEECH_LLM(nn.Module):
|
||||
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
|
||||
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
|
||||
|
||||
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
|
||||
# text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
|
||||
text_input_embeds_list = []
|
||||
for i in range(len(text_label_start_index_list)):
|
||||
text_last_hidden = model_outputs.hidden_states[-1][
|
||||
i,
|
||||
@ -445,14 +446,14 @@ class SPEECH_LLM(nn.Module):
|
||||
+ input_seq_len[i]
|
||||
- 1,
|
||||
]
|
||||
text_last_hidden_lists.append(text_last_hidden)
|
||||
# text_last_hidden_lists.append(text_last_hidden)
|
||||
text_embed = inputs_embeds[
|
||||
i,
|
||||
text_input_start_index_list[i]
|
||||
+ 1 : text_input_start_index_list[i]
|
||||
+ input_seq_len[i],
|
||||
] # exclude bos
|
||||
text_embeds_list.append(text_embed)
|
||||
# text_embeds_list.append(text_embed)
|
||||
|
||||
text_input_embeds = torch.cat(
|
||||
[
|
||||
@ -480,8 +481,9 @@ class SPEECH_LLM(nn.Module):
|
||||
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
|
||||
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
|
||||
logging.warning(
|
||||
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
|
||||
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}\naudio_codes_lens: {audio_codes_lens[i]}\ninput_question_len_list: {input_question_len_list[i]}\ninput_seq_len: {input_seq_len[i]}\n"
|
||||
)
|
||||
# breakpoint()
|
||||
text_input_embeds = text_input_embeds[
|
||||
: audio_embeddings.shape[1] - start_idx
|
||||
]
|
||||
|
@ -68,6 +68,7 @@ from transformers import (
|
||||
)
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torch.utils.data import DistributedSampler, DataLoader
|
||||
from pathlib import Path
|
||||
|
||||
from train import add_model_arguments, add_training_arguments, get_params, get_model
|
||||
from utils import ( # filter_uneven_sized_batch,
|
||||
@ -171,11 +172,14 @@ def data_collator(batch):
|
||||
{"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": item["text"]},
|
||||
]
|
||||
# message_list_item += [
|
||||
# {"role": "user", "content": f"TTS{DEFAULT_SPEECH_TOKEN}"},
|
||||
# {"role": "assistant", "content": item["text"]},
|
||||
# ]
|
||||
messages.append(message_list_item)
|
||||
durations.append(item["duration"])
|
||||
ids.append(item["id"])
|
||||
ids.append(item["index"] if "index" in item else item["id"])
|
||||
lang.append(item["language"])
|
||||
dnsmos.append(item["dnsmos"])
|
||||
|
||||
return {
|
||||
"speech_tokens": speech_tokens,
|
||||
@ -183,7 +187,92 @@ def data_collator(batch):
|
||||
"durations": durations,
|
||||
"ids": ids,
|
||||
"lang": lang,
|
||||
"dnsmos": dnsmos,
|
||||
}
|
||||
|
||||
def data_collator_concate_items(batch, concat_items_num: int = 3):
|
||||
"""Concatenate *concat_items_num* consecutive dataset items into one.
|
||||
|
||||
The function groups the incoming ``batch`` (a list of dataset items)
|
||||
into non-overlapping chunks of *concat_items_num*. For each group it
|
||||
concatenates the textual fields and speech codec tokens so that the
|
||||
model generates one longer utterance instead of several short ones.
|
||||
|
||||
Any remainder (when ``len(batch)`` is not divisible by
|
||||
*concat_items_num*) is also kept as a smaller group.
|
||||
"""
|
||||
|
||||
grouped_speech_tokens, grouped_messages, grouped_durations = [], [], []
|
||||
grouped_ids, grouped_lang = [], []
|
||||
|
||||
# Iterate over the batch in strides of *concat_items_num*
|
||||
for start_idx in range(0, len(batch), concat_items_num):
|
||||
group = batch[start_idx : start_idx + concat_items_num]
|
||||
if not group:
|
||||
continue
|
||||
|
||||
# 1) Speech tokens --------------------------------------------------
|
||||
# ``item['code']`` can be a list[int] or a 1-D tensor. Use the first
|
||||
# element to decide how to concatenate.
|
||||
first_code = group[0]["code"]
|
||||
if isinstance(first_code, torch.Tensor):
|
||||
concat_code = torch.cat([item["code"] for item in group], dim=0)
|
||||
else:
|
||||
# assume list / iterable of ints
|
||||
concat_code = []
|
||||
for item in group:
|
||||
concat_code.extend(item["code"])
|
||||
|
||||
# 2) Text -----------------------------------------------------------
|
||||
concat_text = "".join([item["text"] for item in group])
|
||||
|
||||
# 3) Build chat template messages -----------------------------------
|
||||
message_list_item = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Generate a speech from the following text:\n\n{concat_text}{DEFAULT_SPEECH_TOKEN}",
|
||||
},
|
||||
{"role": "assistant", "content": concat_text},
|
||||
]
|
||||
|
||||
# 4) Misc meta fields ----------------------------------------------
|
||||
total_duration = sum(item["duration"] for item in group)
|
||||
group_ids = [item.get("index", item.get("id")) for item in group]
|
||||
language = group[0].get("language", "")
|
||||
|
||||
# 5) Append to output lists ----------------------------------------
|
||||
grouped_speech_tokens.append(concat_code)
|
||||
grouped_messages.append(message_list_item)
|
||||
grouped_durations.append(total_duration)
|
||||
grouped_ids.append(group_ids)
|
||||
grouped_lang.append(language)
|
||||
|
||||
return {
|
||||
"speech_tokens": grouped_speech_tokens,
|
||||
"messages": grouped_messages,
|
||||
"durations": grouped_durations,
|
||||
"ids": grouped_ids,
|
||||
"lang": grouped_lang,
|
||||
}
|
||||
|
||||
def data_collator_ultra_chat(batch):
|
||||
speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
|
||||
for i, item in enumerate(batch):
|
||||
speech_tokens.append(item["custom"]["speech_token"])
|
||||
text = item["supervisions"][0]["text"]
|
||||
message_list_item = []
|
||||
message_list_item += [
|
||||
{"role": "user", "content": f"Generate a speech from the following text:\n\n{text}{DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": text},
|
||||
]
|
||||
messages.append(message_list_item)
|
||||
durations.append(item["duration"])
|
||||
ids.append(item["id"])
|
||||
|
||||
return {
|
||||
"speech_tokens": speech_tokens,
|
||||
"messages": messages,
|
||||
"durations": durations,
|
||||
"ids": ids,
|
||||
}
|
||||
|
||||
def compute_loss(
|
||||
@ -470,13 +559,21 @@ def run(rank, world_size, args):
|
||||
sampler_state_dict = None
|
||||
if params.sampler_state_dict_path:
|
||||
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
||||
# print(params.dataset)
|
||||
ds = load_dataset(params.dataset, split="train")
|
||||
# shuffle the dataset
|
||||
ds = ds.shuffle(seed=42)
|
||||
train_test_split = ds.train_test_split(test_size=1000, seed=42)
|
||||
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
|
||||
# train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]
|
||||
if params.dataset == "ultra_chat_voice_assistant":
|
||||
data_dir = "data/fbank"
|
||||
json_file_lists = ["data/fbank/cuts_voice_assistant_00001-00049.jsonl", "data/fbank/cuts_ultrachat_train.jsonl.gz"]
|
||||
ds = load_dataset("json", data_files=json_file_lists, split="train")
|
||||
# shuffle the dataset
|
||||
train_dataset = ds.shuffle(seed=42)
|
||||
eval_dataset = load_dataset("json", data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"], split="train")
|
||||
else:
|
||||
data_dir = Path(params.dataset)
|
||||
json_file_lists = [str(file) for file in data_dir.glob("*.jsonl")]
|
||||
ds = load_dataset("json", data_files=json_file_lists, split="train")
|
||||
# shuffle the dataset
|
||||
ds = ds.shuffle(seed=42)
|
||||
train_test_split = ds.train_test_split(test_size=1000, seed=42)
|
||||
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
|
||||
|
||||
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
|
||||
train_dl = StatefulDataLoader(
|
||||
@ -486,7 +583,7 @@ def run(rank, world_size, args):
|
||||
shuffle=False,
|
||||
num_workers=4,
|
||||
prefetch_factor=2,
|
||||
collate_fn=data_collator
|
||||
collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
|
||||
)
|
||||
train_dl.load_state_dict(sampler_state_dict)
|
||||
valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
|
||||
@ -497,7 +594,7 @@ def run(rank, world_size, args):
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
prefetch_factor=1,
|
||||
collate_fn=data_collator
|
||||
collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
|
||||
)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
|
@ -12,12 +12,12 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
import kaldialign
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import numpy as np
|
||||
Pathlike = Union[str, Path]
|
||||
|
||||
|
||||
@ -431,3 +431,45 @@ def write_error_stats(
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
return float(tot_err_rate)
|
||||
|
||||
|
||||
def get_dataset_statistics(dataset, save_filename):
|
||||
speech_token_lengths = []
|
||||
text_lengths = []
|
||||
for item in tqdm(dataset):
|
||||
if 'custom' not in item:
|
||||
speech_token = item["code"]
|
||||
text = item["text"]
|
||||
else:
|
||||
speech_token = item["custom"]["speech_token"]
|
||||
text = item["supervisions"][0]["text"]
|
||||
speech_token_lengths.append(len(speech_token))
|
||||
text_lengths.append(len(text))
|
||||
speech_token_length_array = np.array(speech_token_lengths)
|
||||
text_length_array = np.array(text_lengths)
|
||||
# 计算并存储统计指标
|
||||
def get_length_stats(lengths_array):
|
||||
length_stats = []
|
||||
length_stats.append(["count", f"{len(lengths_array)}"]) # 总数
|
||||
length_stats.append(["mean", f"{np.mean(lengths_array):.1f}"])
|
||||
length_stats.append(["std", f"{np.std(lengths_array):.1f}"])
|
||||
length_stats.append(["min", f"{np.min(lengths_array):.1f}"])
|
||||
length_stats.append(["25%", f"{np.percentile(lengths_array, 25):.1f}"])
|
||||
length_stats.append(["50% (median)", f"{np.median(lengths_array):.1f}"]) # median 和 50% percentile 是一样的
|
||||
length_stats.append(["75%", f"{np.percentile(lengths_array, 75):.1f}"])
|
||||
length_stats.append(["99%", f"{np.percentile(lengths_array, 99):.1f}"])
|
||||
length_stats.append(["99.5%", f"{np.percentile(lengths_array, 99.5):.1f}"])
|
||||
length_stats.append(["99.9%", f"{np.percentile(lengths_array, 99.9):.1f}"])
|
||||
length_stats.append(["max", f"{np.max(lengths_array):.1f}"])
|
||||
return length_stats
|
||||
speech_length_stats = get_length_stats(speech_token_length_array)
|
||||
text_length_stats = get_length_stats(text_length_array)
|
||||
with open(save_filename, "w") as f:
|
||||
print("speech_tokens 长度统计指标:", file=f)
|
||||
for stat_name, stat_value in speech_length_stats:
|
||||
print(f"{stat_name:<15}: {stat_value}", file=f)
|
||||
print("\ntext 长度统计指标:", file=f)
|
||||
for stat_name, stat_value in text_length_stats:
|
||||
print(f"{stat_name:<15}: {stat_value}", file=f)
|
||||
|
||||
return speech_token_lengths, text_lengths
|
||||
|
Loading…
x
Reference in New Issue
Block a user