Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

Commit f81363d324 (parent e65725810c): add speech continuation pretraining
@@ -122,7 +122,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
 fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "stage 1: Compute fbank feature from huggingface"
+  log "stage 6: Compute fbank feature from huggingface"
   # CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
   #   --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
   #   --out-dir data/fbank_voice_assistant \
@@ -161,10 +161,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     --subset test --split test \
     --audio-key audio --text-key text \
     --prefix gigaspeech
-fi

-if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
-  log "stage 9: Compute fbank feature from huggingface"
   CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
     --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
     --out-dir data/fbank_gigaspeech \
@@ -195,7 +192,7 @@ fi


 if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-  log "stage 11: Decoding EN, only support batch_size=1 for now."
+  log "stage 11: Decoding EN, val set only support batch_size=1 for now."
   exp_dir=./qwen_omni/exp_speech2speech_en_continue
   # cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
   python3 ./qwen_omni/decode.py \
@@ -256,3 +253,71 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
       --output-dir test_result
   done
 fi
+
+
+if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
+  log "stage 15: Training Speech2Speech Model, adaptor only"
+  exp_dir=./qwen_omni/exp_speech2text
+  ngpu=2
+  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+    --max-duration 600 \
+    --enable-musan False \
+    --audio-key audio --text-key continuation \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+    --on-the-fly-feats True \
+    --deepspeed \
+    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --dataset-format speech_continuation \
+    --start-epoch 2 --pretrained-model-path $exp_dir/epoch-1/pytorch_model.bin \
+    --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
+fi
+
+if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
+  log "stage 16: Training Speech2Speech Model, adaptor only"
+  exp_dir=./qwen_omni/exp_speech2text
+  ngpu=4
+
+  latest_checkpoint_step=-1
+  # Check if exp_dir exists and is a directory
+  if [ -d "$exp_dir" ]; then
+    # List directories matching checkpoint-* and find the one with the largest step number
+    for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+      checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+      # Extract step number using parameter expansion
+      current_step=${checkpoint_name#checkpoint-}
+      # Ensure current_step is a number
+      if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
+        latest_checkpoint_step=$current_step
+      fi
+    done
+  fi
+
+  train_cmd_args="--max-duration 1200 \
+    --enable-musan False \
+    --audio-key audio --text-key continuation \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+    --on-the-fly-feats True \
+    --deepspeed \
+    --huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \
+    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --dataset-format speech_continuation \
+    --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False"
+
+  if [ "$latest_checkpoint_step" -ge 0 ]; then
+    log "Continuing training from checkpoint-$latest_checkpoint_step"
+    step=$latest_checkpoint_step
+    train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
+  else
+    log "Starting training from scratch as no checkpoint was found in $exp_dir"
+    # No pretrained model or sampler state dict needed for the first run
+  fi
+
+  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+    $train_cmd_args
+fi
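Stage 16 resumes from the newest checkpoint by scanning for checkpoint-<step> directories and keeping the largest step. A minimal Python sketch of the same scan (the exp_dir value and directory names below are illustrative assumptions, not part of the recipe):

# Sketch: mirror of the bash loop above that finds the latest checkpoint-<step> directory.
from pathlib import Path

def latest_checkpoint_step(exp_dir: str) -> int:
    latest = -1
    for d in Path(exp_dir).glob("checkpoint-*"):
        if not d.is_dir():
            continue
        step_str = d.name[len("checkpoint-"):]
        if step_str.isdigit():
            latest = max(latest, int(step_str))
    return latest  # -1 means no checkpoint found, i.e. start from scratch

# e.g. with checkpoint-1000/ and checkpoint-3000/ present this returns 3000
print(latest_checkpoint_step("./qwen_omni/exp_speech2text"))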
@@ -24,7 +24,7 @@ from pathlib import Path
 from typing import Any, Dict, Optional

 import torch
-from datasets import load_dataset
+from datasets import interleave_datasets, load_dataset
 from lhotse import (
     CutSet,
     WhisperFbank,
@@ -36,6 +36,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
+    PerturbSpeed,
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
@@ -47,7 +48,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
 from lhotse.utils import fix_random_seed
 from speech_dataset import K2SpeechRecognitionDataset
 from torch.utils.data import DataLoader
-from utils import str2bool
+from utils import get_rank, str2bool


 class _SeedWorkers:
@@ -123,6 +124,14 @@ class AsrDataModule:
             "extraction. Will drop existing precomputed feature manifests "
             "if available.",
         )
+        group.add_argument(
+            "--on-the-fly-speed-perturb",
+            type=str2bool,
+            default=True,
+            help="When enabled, use on-the-fly speed perturbation. "
+            "Will drop existing precomputed feature manifests "
+            "if available.",
+        )
         group.add_argument(
             "--shuffle",
             type=str2bool,
@@ -188,27 +197,27 @@ class AsrDataModule:
         group.add_argument(
             "--huggingface-dataset-path-or-name",
             type=str,
-            default="/workspace/Belle_1.4M-SLAM-Omni",
+            default=None,
             help="The path or name of the Huggingface dataset",
         )
         group.add_argument(
             "--audio-key",
             type=str,
-            default="question_audio",
+            default="audio",
             help="The key in the Huggingface dataset containing the audio data",
         )
         group.add_argument(
             "--text-key",
             type=str,
-            default="answer",
+            default="text",
             help="The key in the Huggingface dataset containing the text data",
         )
-        group.add_argument(
-            "--resample-to-16kHz",
-            type=str2bool,
-            default=True,
-            help="Resample audio to 16kHz. Default: False.",
-        )
+        # group.add_argument(
+        #     "--resample-to-16kHz",
+        #     type=str2bool,
+        #     default=True,
+        #     help="Resample audio to 16kHz. Default: False.",
+        # )

     def train_dataloaders(
         self,
@@ -232,6 +241,8 @@ class AsrDataModule:
             )
         else:
             logging.info("Disable MUSAN")
+        if self.args.on_the_fly_speed_perturb and self.args.on_the_fly_feats:
+            transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms

         input_transforms = []
         if self.args.enable_spec_aug:
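The two added lines prepend lhotse's PerturbSpeed to the cut transforms whenever both --on-the-fly-speed-perturb and --on-the-fly-feats are enabled. With factors [0.9, 1.1] and p=2/3, each cut keeps its original speed with probability 1/3 and is otherwise slowed down or sped up. A rough sketch of that per-cut sampling behaviour in plain Python (the toy duration is an illustrative assumption; it does not call lhotse):

# Sketch of the per-cut decision behind PerturbSpeed(factors=[0.9, 1.1], p=2/3).
import random

def perturb_factor(p: float = 2 / 3, factors=(0.9, 1.1)) -> float:
    # returns 1.0 (leave the cut untouched) or one of the speed factors
    return random.choice(factors) if random.random() < p else 1.0

factor = perturb_factor()
original_duration = 10.0  # seconds, illustrative
print(factor, original_duration / factor)  # speeding up by `factor` shortens the cut accordingly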
@@ -260,9 +271,11 @@ class AsrDataModule:
             logging.info("Disable SpecAugment")

         logging.info("About to create train dataset")
+        rank = get_rank()
+
         train = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(
-                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
             )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
@@ -271,26 +284,6 @@ class AsrDataModule:
             return_cuts=self.args.return_cuts,
         )

-        # if self.args.on_the_fly_feats:
-        #     # NOTE: the PerturbSpeed transform should be added only if we
-        #     # remove it from data prep stage.
-        #     # Add on-the-fly speed perturbation; since originally it would
-        #     # have increased epoch size by 3, we will apply prob 2/3 and use
-        #     # 3x more epochs.
-        #     # Speed perturbation probably should come first before
-        #     # concatenation, but in principle the transforms order doesn't have
-        #     # to be strict (e.g. could be randomized)
-        #     # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
-        #     # Drop feats to be on the safe side.
-        #     train = K2SpeechRecognitionDataset(
-        #         cut_transforms=transforms,
-        #         input_strategy=OnTheFlyFeatures(
-        #             WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
-        #         ),
-        #         input_transforms=input_transforms,
-        #         return_cuts=self.args.return_cuts,
-        #     )
-
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
             train_sampler = DynamicBucketingSampler(
@@ -298,8 +291,7 @@ class AsrDataModule:
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                buffer_size=self.args.num_buckets * 2000,
-                shuffle_buffer_size=self.args.num_buckets * 5000,
+                buffer_size=self.args.num_buckets * 1000,
                 drop_last=self.args.drop_last,
             )
         else:
@@ -339,10 +331,10 @@ class AsrDataModule:
             CutSet for validation.
         """
         logging.info("About to create dev dataset")
+        rank = get_rank()
         validate = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(
-                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}"))
             )
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
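Both dataloaders now build the on-the-fly WhisperFbank extractor on cuda:{rank} instead of plain "cuda", so each torchrun worker extracts features on its own GPU rather than all ranks piling onto device 0. The repo's utils.get_rank is not shown in this diff; the following is only a hedged sketch of the general idea, assuming one process per GPU on a single node and the usual torchrun environment variables:

# Sketch only: choose a per-process CUDA device for feature extraction.
import os
import torch

def get_rank() -> int:
    # assumption: torchrun exports LOCAL_RANK/RANK; fall back to 0 for single-process runs
    return int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", 0)))

device = f"cuda:{get_rank()}" if torch.cuda.is_available() else "cpu"
print(device)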
@@ -470,25 +462,231 @@ class AsrDataModule:
         )
         return {"test": VoiceAssistant_cuts}

-    # def train_cuts_en_vocalnet(self) -> CutSet:
-    #     logging.info("About to get train cuts")
-    #     VoiceAssistant_cuts = load_manifest_lazy(
-    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
-    #     )
-    #     return VoiceAssistant_cuts
-
-    # @lru_cache()
-    # def valid_cuts_en_vocalnet(self) -> CutSet:
-    #     logging.info("About to get valid cuts")
-    #     VoiceAssistant_cuts = load_manifest_lazy(
-    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
-    #     )
-    #     return VoiceAssistant_cuts
-
-    # @lru_cache()
-    # def test_cuts_en_vocalnet(self) -> CutSet:
-    #     logging.info("About to get test cuts")
-    #     VoiceAssistant_cuts = load_manifest_lazy(
-    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
-    #     )
-    #     return VoiceAssistant_cuts
+    @lru_cache()
+    def train_cuts_ultravox(self) -> CutSet:
+        logging.info("About to get train cuts")
+        if self.args.huggingface_dataset_path_or_name is not None:
+            librispeech_path = (
+                self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
+            )
+            people_speech_path = (
+                self.args.huggingface_dataset_path_or_name + "/peoples_speech"
+            )
+            gigaspeech_path = self.args.huggingface_dataset_path_or_name + "/gigaspeech"
+        else:
+            librispeech_path = "fixie-ai/librispeech_asr"
+            people_speech_path = "fixie-ai/peoples_speech"
+            gigaspeech_path = "fixie-ai/gigaspeech"
+        # 148_688
+        librispeech_other = load_dataset(
+            librispeech_path, "other", split="train.500", streaming=True
+        )
+        # 104_014
+        librispeech_clean_360 = load_dataset(
+            librispeech_path, "clean", split="train.360", streaming=True
+        )
+        # 28_539
+        librispeech_clean_100 = load_dataset(
+            librispeech_path, "clean", split="train.100", streaming=True
+        )
+
+        # 1_501_271
+        people_speech_clean = load_dataset(
+            people_speech_path, "clean", split="train", streaming=True
+        )
+        # 548_000
+        people_speech_dirty_sa = load_dataset(
+            people_speech_path, "dirty_sa", split="train", streaming=True
+        )
+
+        # 8_266_422
+        gigaspeech = load_dataset(
+            gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
+        )
+
+        librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_100,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+
+        librispeech_other_cuts = CutSet.from_huggingface_dataset(
+            librispeech_other,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+
+        librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_360,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+
+        gigaspeech_cuts = CutSet.from_huggingface_dataset(
+            gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
+        )
+
+        people_speech_clean_cuts = CutSet.from_huggingface_dataset(
+            people_speech_clean,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+
+        people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset(
+            people_speech_dirty_sa,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+
+        return CutSet.mux(
+            librispeech_clean_100_cuts,
+            librispeech_clean_360_cuts,
+            librispeech_other_cuts,
+            gigaspeech_cuts,
+            people_speech_clean_cuts,
+            people_speech_dirty_sa_cuts,
+            weights=[
+                28539,
+                104014,
+                148688,
+                8266422,
+                1501271,
+                548000,
+            ],
+        )
+
+    # @lru_cache()
+    # def train_cuts_ultravox(self) -> CutSet:
+    # logging.info("About to get train cuts")
+    # keep_columns = ["audio", "text", "continuation", "id"]
+    # librispeech_path="fixie-ai/librispeech_asr"
+    # # 148_688
+    # librispeech_other = load_dataset(librispeech_path, 'other', split='train.500', streaming=True)
+    # # 104_014
+    # librispeech_clean_360 = load_dataset(librispeech_path, 'clean', split='train.360', streaming=True)
+    # # 28_539
+    # librispeech_clean_100 = load_dataset(librispeech_path, 'clean', split='train.100', streaming=True)
+
+    # cols_to_remove = librispeech_clean_100.column_names
+    # cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
+    # librispeech_clean_100 = librispeech_clean_100.remove_columns(cols_to_remove)
+    # librispeech_clean_360 = librispeech_clean_360.remove_columns(cols_to_remove)
+    # librispeech_other = librispeech_other.remove_columns(cols_to_remove)
+    # people_speech_path="fixie-ai/peoples_speech"
+    # # 1_501_271
+    # people_speech_clean = load_dataset(people_speech_path, 'clean', split='train', streaming=True)
+    # # 548_000
+    # people_speech_dirty_sa = load_dataset(people_speech_path, 'dirty_sa', split='train', streaming=True)
+    # cols_to_remove = people_speech_clean.column_names
+    # cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
+    # people_speech_clean = people_speech_clean.remove_columns(cols_to_remove)
+    # people_speech_dirty_sa = people_speech_dirty_sa.remove_columns(cols_to_remove)
+
+    # # 8_266_422
+    # gigaspeech_path="fixie-ai/gigaspeech"
+    # gigaspeech = load_dataset(gigaspeech_path, 'xl-empty-audio-removed', split='train', streaming=True)
+    # # first rename segment_id to id
+    # gigaspeech = gigaspeech.rename_column("segment_id", "id")
+    # cols_to_remove = gigaspeech.column_names
+    # cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
+    # gigaspeech = gigaspeech.remove_columns(cols_to_remove)
+
+    # total_item = 104014 + 28539 + 8266422 + 1501271 + 548000 + 148688
+    # final_datasets = interleave_datasets([
+    #     librispeech_clean_100,
+    #     librispeech_clean_360,
+    #     gigaspeech,
+    #     people_speech_clean,
+    #     people_speech_dirty_sa,
+    #     librispeech_other,
+    # ], probabilities=[
+    #     28539 / total_item,
+    #     104014 / total_item,
+    #     8266422 / total_item,
+    #     1501271 / total_item,
+    #     548000 / total_item,
+    #     148688 / total_item,
+    # ])
+
+    # train_cuts = CutSet.from_huggingface_dataset(
+    #     final_datasets, audio_key=self.args.audio_key, text_key=self.args.text_key
+    # )
+
+    # return train_cuts
+
+    @lru_cache()
+    def valid_cuts_ultravox(self) -> CutSet:
+        logging.info("About to get valid cuts")
+        librispeech_path = "fixie-ai/librispeech_asr"
+        librispeech_clean_valid = load_dataset(
+            librispeech_path, "clean", split="validation", streaming=True
+        )
+        librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_valid,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+        return librispeech_clean_valid_cuts
+
+    @lru_cache()
+    def train_cuts_librispeech(self) -> CutSet:
+        logging.info("About to get train cuts")
+
+        # librispeech_path="fixie-ai/librispeech_asr"
+        librispeech_path = "/workspace/slam/librispeech_asr"
+        # 148_688
+        librispeech_other = load_dataset(
+            librispeech_path, "other", split="train.500", streaming=True
+        )
+        # 104_014
+        librispeech_clean_360 = load_dataset(
+            librispeech_path, "clean", split="train.360", streaming=True
+        )
+        # 28_539
+        librispeech_clean_100 = load_dataset(
+            librispeech_path, "clean", split="train.100", streaming=True
+        )
+
+        librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_100,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+
+        librispeech_other_cuts = CutSet.from_huggingface_dataset(
+            librispeech_other,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+
+        librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_360,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+
+        return CutSet.mux(
+            librispeech_clean_100_cuts,
+            librispeech_clean_360_cuts,
+            librispeech_other_cuts,
+            weights=[
+                28539,
+                104014,
+                148688,
+            ],
+        )
+
+    @lru_cache()
+    def train_cuts_gigaspeech(self) -> CutSet:
+        logging.info("About to get train cuts")
+        gigaspeech_path = "fixie-ai/gigaspeech"
+        gigaspeech = load_dataset(
+            gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True
+        )
+
+        gigaspeech_cuts = CutSet.from_huggingface_dataset(
+            gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
+        )
+
+        return gigaspeech_cuts
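CutSet.mux interleaves the streaming corpora lazily, and the weights are the approximate example counts noted in the comments, so each source is drawn roughly in proportion to its size. A quick check of the resulting mixture (plain arithmetic, no lhotse needed):

# Sketch: turn the mux weights used in train_cuts_ultravox into sampling probabilities.
weights = {
    "librispeech_clean_100": 28_539,
    "librispeech_clean_360": 104_014,
    "librispeech_other_500": 148_688,
    "gigaspeech_xl": 8_266_422,
    "peoples_speech_clean": 1_501_271,
    "peoples_speech_dirty_sa": 548_000,
}
total = sum(weights.values())
for name, w in weights.items():
    print(f"{name}: {w / total:.3f}")
# gigaspeech_xl dominates at roughly 0.78 of the mixture; librispeech_clean_100 is under 0.003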
@@ -10,3 +10,4 @@ transformers>=4.37.0
 flash-attn
 peft
 torchmetrics
+triton==3.3.0 # may be violate with openai-whisper
@@ -68,24 +68,26 @@ from transformers import (
     Qwen2Config,
     Qwen2ForCausalLM,
 )
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

-# from icefall import diagnostics
-from utils import get_rank, get_world_size
 # from icefall.env import get_env_info
+# from icefall import diagnostics
 from utils import ( # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
+    get_rank,
+    get_world_size,
     setup_logger,
     str2bool,
 )
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

 DEFAULT_SPEECH_TOKEN = "<speech>"
 try:
-    torch.multiprocessing.set_start_method('spawn')
+    torch.multiprocessing.set_start_method("spawn")
 except RuntimeError:
     pass


 def set_batch_count(model: nn.Module, batch_count: float) -> None:
     for module in model.modules():
         if hasattr(module, "batch_count"):
@@ -272,7 +274,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 5000,
+            "valid_interval": 3000,
             # "env_info": get_env_info(),
         }
     )
@@ -332,6 +334,21 @@ def process_batch_vocalnet(batch: dict):
     return messages, answer_cosyvoice_speech_token


+def process_batch_speech_continuation(batch: dict):
+    messages = []
+    for i in range(len(batch["supervisions"]["text"])):
+        message = [
+            {
+                "role": "user",
+                "content": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
+            },
+            {"role": "assistant", "content": batch["supervisions"]["text"][i]},
+        ]
+        # transcript = batch["supervisions"]["cut"][i].custom["text"]
+        messages.append(message)
+    return messages
+
+
 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
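For the new speech_continuation format, every supervision text becomes a two-turn chat: a fixed user prompt that embeds the speech placeholder token, and an assistant turn holding the ground-truth continuation. A small usage sketch with a hand-made batch (the batch contents are illustrative; the function body mirrors the one added above, minus the commented-out transcript line):

DEFAULT_SPEECH_TOKEN = "<speech>"

def process_batch_speech_continuation(batch: dict):
    messages = []
    for i in range(len(batch["supervisions"]["text"])):
        message = [
            {
                "role": "user",
                "content": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}",
            },
            {"role": "assistant", "content": batch["supervisions"]["text"][i]},
        ]
        messages.append(message)
    return messages

batch = {"supervisions": {"text": ["and so the story continued long into the night"]}}
print(process_batch_speech_continuation(batch))
# -> one [user, assistant] pair per supervision; the assistant side is the continuation target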
@@ -429,13 +446,13 @@ def compute_loss(
     feature = feature.to(device)
     feature = feature.transpose(1, 2)  # (N, C, T)

-    batch_idx_train = params.batch_idx_train
-
     # WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
     if params.dataset_format == "slam_omni":
         messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
     elif params.dataset_format == "vocalnet":
         messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
+    elif params.dataset_format == "speech_continuation":
+        messages = process_batch_speech_continuation(batch)
     else:
         raise ValueError(f"Unknown dataset format: {params.dataset_format}")

@@ -566,8 +583,11 @@ def train_one_epoch(
       The rank of the node in DDP training. If no DDP is used, it should
       be set to 0.
     """
-    model.encoder_projector.train()
+    # model.encoder_projector.train()
+    model.train()
+    model.encoder.eval()
+    if not params.unfreeze_llm:
+        model.llm.eval()
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
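Switching from model.encoder_projector.train() to model.train() followed by encoder.eval() (and llm.eval() when the LLM stays frozen) matters because requires_grad=False only stops weight updates; it does not disable dropout or other train-time behaviour inside the frozen submodules. A minimal sketch of the pattern with stand-in module names (not the repo's actual model class):

# Sketch: freeze one branch's weights and keep it in eval mode while the rest trains.
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))  # frozen branch
        self.projector = nn.Linear(8, 8)  # the part being trained

model = Toy()
for p in model.encoder.parameters():
    p.requires_grad = False  # no gradient updates for the encoder

model.train()         # enables dropout everywhere...
model.encoder.eval()  # ...so force the frozen branch back to deterministic eval mode
print(model.encoder[1].training, model.projector.training)  # False True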
@@ -583,6 +603,9 @@ def train_one_epoch(
                     world_size=world_size,
                 )
                 model.train()
+                model.encoder.eval()
+                if not params.unfreeze_llm:
+                    model.llm.eval()
                 logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
                 logging.info(
                     f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -594,7 +617,7 @@ def train_one_epoch(
             if batch_idx != 0:
                 model.save_checkpoint(
                     save_dir=params.exp_dir,
-                    tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                    tag=f"zero-checkpoint-{params.batch_idx_train}",
                     client_state={},
                     exclude_frozen_parameters=True,
                 )
@@ -602,18 +625,18 @@ def train_one_epoch(
                 if rank == 0:
                     convert_zero_checkpoint_to_fp32_state_dict(
                         params.exp_dir,
-                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
-                        tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                        f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
+                        tag=f"zero-checkpoint-{params.batch_idx_train}",
                         exclude_frozen_parameters=True,
                     )
                     # save sampler state dict into checkpoint
                     sampler_state_dict = train_dl.sampler.state_dict()
                     torch.save(
                         sampler_state_dict,
-                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
+                        f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
                     )
                     os.system(
-                        f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
+                        f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
                     )
         try:
             with torch.amp.autocast("cuda", enabled=params.use_fp16):
@@ -687,9 +710,9 @@ def run(rank, world_size, args):

     fix_random_seed(params.seed)

-    setup_logger(f"{params.exp_dir}/log/log-train")
+    if rank == 0:
+        setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info(params)

     logging.info("About to create model")

     replace_whisper_encoder_forward()
@@ -698,7 +721,6 @@ def run(rank, world_size, args):
     speech_encoder_dim = whisper_model.dims.n_audio_state
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False
-    speech_encoder.eval()

     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)

@@ -721,7 +743,7 @@ def run(rank, world_size, args):
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-        llm.eval()
+
     else:
         if params.use_lora:
             lora_config = LoraConfig(
@@ -809,6 +831,9 @@ def run(rank, world_size, args):
     if params.pretrained_model_path:
         checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
         missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+        # set params.batch_idx_train according to the checkpoint name
+        if "checkpoint-" in params.pretrained_model_path:
+            params.batch_idx_train = int(params.pretrained_model_path.split("-")[-1])

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
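Resuming from a mid-epoch checkpoint also restores the global step counter by parsing it out of the checkpoint path, keeping params.batch_idx_train in sync with the checkpoint-<step> naming used when saving. A quick sketch of the string handling (the example tag is illustrative; the parse assumes the step number is the final "-"-separated field of the path):

# Sketch: recover the training step from a checkpoint tag such as "checkpoint-3000".
tag = "checkpoint-3000"
step = int(tag.split("-")[-1])
print(step)  # 3000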
@@ -842,21 +867,22 @@ def run(rank, world_size, args):
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        if c.duration < 1.0 or c.duration > 30.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
-            return False
-        codec_len = (
-            len(c.custom["answer_cosyvoice_speech_token"])
-            if "answer_cosyvoice_speech_token" in c.custom
-            else len(c.custom["speech_token"])
-        )
-        if codec_len > 2200:
+        if c.duration < 1.0 or c.duration > 29.5:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
+        if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
+            codec_len = (
+                len(c.custom["answer_cosyvoice_speech_token"])
+                if "answer_cosyvoice_speech_token" in c.custom
+                else len(c.custom["speech_token"])
+            )
+            if codec_len > 2200:
+                logging.warning(
+                    f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
+                )
+                return False
         return True

     if params.dataset_format == "slam_omni":
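The reworked filter keeps the 1.0 to 29.5 s duration gate for every cut and only applies the 2200-token codec-length cap when the cut actually carries speech-token labels, which is what lets label-free speech-continuation cuts through. A toy sketch of the predicate (the stand-in cut objects are assumptions; real cuts are lhotse Cut instances and the real function also logs warnings):

# Sketch: the keep/drop decision made by remove_short_and_long_utt after this change.
from types import SimpleNamespace

def keep(c) -> bool:
    if c.duration < 1.0 or c.duration > 29.5:
        return False
    if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
        codec_len = (
            len(c.custom["answer_cosyvoice_speech_token"])
            if "answer_cosyvoice_speech_token" in c.custom
            else len(c.custom["speech_token"])
        )
        if codec_len > 2200:
            return False
    return True

print(keep(SimpleNamespace(duration=12.3, custom={})))  # True: continuation cut, duration OK
print(keep(SimpleNamespace(duration=12.3, custom={"speech_token": [0] * 3000})))  # False: codec too long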
@@ -865,6 +891,11 @@ def run(rank, world_size, args):
     elif params.dataset_format == "vocalnet":
         train_cuts = data_module.train_cuts_en_vocalnet()
         valid_cuts = data_module.valid_cuts_en_vocalnet()
+    elif params.dataset_format == "speech_continuation":
+        # train_cuts = data_module.train_cuts_ultravox()
+        # train_cuts = data_module.train_cuts_gigaspeech()
+        train_cuts = data_module.train_cuts_librispeech()
+        valid_cuts = data_module.valid_cuts_ultravox()
     else:
         raise ValueError(f"Unknown dataset format: {params.dataset_format}")

@@ -879,7 +910,7 @@ def run(rank, world_size, args):
     train_dl = data_module.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
+    # train_dl = data_module.valid_dataloaders(train_cuts)
     valid_dl = data_module.valid_dataloaders(valid_cuts)

     if args.tensorboard and rank == 0:
@@ -913,25 +944,25 @@ def run(rank, world_size, args):

     model.save_checkpoint(
         save_dir=params.exp_dir,
-        tag=f"epoch-{params.cur_epoch}",
+        tag=f"zero-epoch-{params.cur_epoch}",
         client_state={},
         exclude_frozen_parameters=True,
     )
     if rank == 0:
         convert_zero_checkpoint_to_fp32_state_dict(
             params.exp_dir,
-            f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-            tag=f"epoch-{params.cur_epoch}",
+            f"{params.exp_dir}/epoch-{params.cur_epoch}",
+            tag=f"zero-epoch-{params.cur_epoch}",
             exclude_frozen_parameters=True,
         )
         # save sampler state dict into checkpoint
         sampler_state_dict = train_dl.sampler.state_dict()
         torch.save(
             sampler_state_dict,
-            f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
+            f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
         )

-        os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
+        os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")

     logging.info("Done!")

@@ -971,6 +1002,7 @@ def main():

     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    warnings.filterwarnings("ignore", category=FutureWarning)
     run(rank=rank, world_size=world_size, args=args)
