diff --git a/egs/speech_llm/ASR_LLM/run.sh b/egs/speech_llm/ASR_LLM/run.sh
index 7a6c39631..123f74c03 100755
--- a/egs/speech_llm/ASR_LLM/run.sh
+++ b/egs/speech_llm/ASR_LLM/run.sh
@@ -1,15 +1,14 @@
-export PYTHONPATH=$PYTHONPATH:/workspace/asr/icefall
+export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
 # pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
 # pip install -r whisper/requirements.txt
-
-method=mask_predict
-# method=cif_ar_distill_embedding
-torchrun --nproc_per_node 8 ./parawhisper/train.py \
-  --max-duration 200 \
-  --exp-dir parawhisper/exp_large_v2_${method} \
-  --model-name large-v2 \
+export CUDA_VISIBLE_DEVICES=0,1
+torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
+  --max-duration 80 \
+  --exp-dir ./whisper_llm_zh/exp_test \
+  --speech-encoder-path-or-name tiny \
+  --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
   --manifest-dir data/fbank \
-  --method $method \
   --deepspeed \
-  --deepspeed_config ./whisper/ds_config_zero1.json
\ No newline at end of file
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --use-flash-attn False
\ No newline at end of file
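Note: the referenced ./whisper_llm_zh/ds_config_zero1.json is not part of this diff, so its contents are not shown. As a purely hypothetical sketch of what a ZeRO stage-1 config of this kind usually looks like (every value below is an assumption, not the recipe's actual setting; DeepSpeed also accepts such a dict directly via deepspeed.initialize(config=...)):

    # Hypothetical ZeRO stage-1 DeepSpeed config, written as a Python dict.
    # None of these values are taken from the repo; they only illustrate the shape.
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "zero_optimization": {"stage": 1},  # stage 1: partition optimizer states only
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        "fp16": {"enabled": True},
    }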
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
index e2b9d7ecc..e5096cb4d 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -1,7 +1,8 @@
 from torch import nn
 import torch
+from transformers.trainer_pt_utils import LabelSmoother

-DEFAULT_SPEECH_TOKEN = -1997 # "<speech>"
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index


 class EncoderProjector(nn.Module):
@@ -18,7 +19,6 @@ class EncoderProjector(nn.Module):
         return x

 class SPEECH_LLM(nn.Module):
-    # https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/slam_model.py
     def __init__(
         self,
         encoder: nn.Module,
@@ -28,8 +28,12 @@ class SPEECH_LLM(nn.Module):
         super().__init__()
         self.encoder = encoder
+        for name, param in encoder.named_parameters():
+            param.requires_grad = False
         self.encoder.eval()
         self.llm = llm
+        for name, param in llm.named_parameters():
+            param.requires_grad = False
         self.llm.eval()
         self.encoder_projector = encoder_projector
         self.encoder_outputs_downsample_rate = 4
@@ -39,11 +43,11 @@ class SPEECH_LLM(nn.Module):
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
         # 1. Create a mask to know where special speech tokens are
-        special_speech_token_mask = input_ids == DEFAULT_SPEECH_TOKEN
+        special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
         num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
         # Compute the maximum embed dimension
         max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length
-        batch_indices, non_speech_indices = torch.where(input_ids != DEFAULT_SPEECH_TOKEN)
+        batch_indices, non_speech_indices = torch.where(input_ids != self.llm.config.default_speech_token_id)
         # 2. Compute the positions where text should be written
         # Calculate new positions for text tokens in merged speech-text sequence.
@@ -65,7 +69,7 @@ class SPEECH_LLM(nn.Module):
         )
         if labels is not None:
             final_labels = torch.full(
-                (batch_size, max_embed_dim), self., dtype=input_ids.dtype, device=input_ids.device
+                (batch_size, max_embed_dim), IGNORE_TOKEN_ID, dtype=input_ids.dtype, device=input_ids.device
             )
         # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
         # set the corresponding tensors into their correct target device.
@@ -128,17 +132,17 @@ class SPEECH_LLM(nn.Module):
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )

-        outputs = self.language_model(
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        logits = outputs[0]
+        # outputs = self.llm(
+        #     attention_mask=attention_mask,
+        #     position_ids=position_ids,
+        #     past_key_values=past_key_values,
+        #     inputs_embeds=inputs_embeds,
+        #     use_cache=use_cache,
+        #     output_attentions=output_attentions,
+        #     output_hidden_states=output_hidden_states,
+        #     return_dict=return_dict,
+        # )
+        # logits = outputs[0]

         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
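A quick arithmetic check of the max_embed_dim formula kept above: each speech placeholder in input_ids is replaced by speech_len projected encoder frames, so it grows by speech_len - 1 positions, while ordinary text tokens keep one position each. A minimal, self-contained illustration (the concrete numbers below are made up):

    # Illustration only: why max_embed_dim = num_speech_tokens * (speech_len - 1) + seq_len.
    sequence_length = 10    # prompt tokens, one of which is the speech placeholder
    num_speech_tokens = 1   # occurrences of the placeholder in input_ids
    speech_len = 187        # projected encoder frames per utterance (made-up value)

    max_embed_dim = num_speech_tokens * (speech_len - 1) + sequence_length
    print(max_embed_dim)    # 196 = 9 text positions + 187 speech positions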
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
index ad48c6bf0..f6eacab01 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
@@ -255,4 +255,16 @@ class MultiDataset:
             self.fbank_dir / "aishell_cuts_train.jsonl.gz"
         )

-        return aishell_cuts
\ No newline at end of file
+        return aishell_cuts
+
+
+    def aishell_dev_cuts(self) -> CutSet:
+        logging.info("About to get multidataset dev cuts")
+
+        # AISHELL
+        logging.info("Loading Aishell set in lazy mode")
+        aishell_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
+        )
+
+        return aishell_dev_cuts
\ No newline at end of file
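A small usage sketch for the new dev split. The manifest path matches the diff; the inspection loop is illustrative only and assumes the fbank manifests have already been prepared:

    from lhotse import load_manifest_lazy

    cuts = load_manifest_lazy("data/fbank/aishell_cuts_dev.jsonl.gz")

    # Lazy manifests stream from disk, so peeking at the first cut is cheap.
    for cut in cuts:
        print(cut.id, cut.duration, cut.supervisions[0].text)
        break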
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 720b2c6ff..b3674be01 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -39,13 +39,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import deepspeed
 import k2
-import optim
+# import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import whisper
 from asr_datamodule import AsrDataModule
-from model import SPEECH_LLM, EncoderProjector
+from model import SPEECH_LLM, EncoderProjector, IGNORE_TOKEN_ID
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
@@ -59,7 +59,7 @@ from torch.cuda.amp import GradScaler
 from torch.nn.functional import pad as pad_tensor
 # from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

 from icefall import diagnostics
@@ -78,13 +78,12 @@ from icefall.utils import (
 )
 from transformers import AutoModelForCausalLM, AutoTokenizer

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+import transformers
+from transformers.trainer_pt_utils import LabelSmoother
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+DEFAULT_SPEECH_TOKEN = "<speech>"

-
-def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
-    if isinstance(model, DDP):
-        # get underlying nn.Module
-        model = model.module
+def set_batch_count(model: nn.Module, batch_count: float) -> None:
     for module in model.modules():
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count
@@ -240,6 +239,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)

@@ -272,6 +278,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "allowed_excess_duration_ratio": 0.1,
+            "subsampling_factor": 2,
             "frame_shift_ms": 10,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -357,7 +364,7 @@ def get_params() -> AttributeDict:
 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     batch: dict,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
@@ -397,7 +404,6 @@ def compute_loss(
             texts.append(
                 tokenizer.apply_chat_template(
                     msg,
-                    chat_template=TEMPLATE,
                     tokenize=True,
                     add_generation_prompt=False,
                     padding="max_length",
@@ -405,8 +411,9 @@
                     truncation=True,
                 )
             )
-        # model_inputs = tokenizer([text], return_tensors="pt").to(device)
+
         input_ids = torch.tensor(texts, dtype=torch.int)
+        # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
         target_ids = input_ids.clone()
         target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
         attention_mask = input_ids.ne(tokenizer.pad_token_id)
@@ -453,47 +460,49 @@ def compute_loss(
     allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
     batch = filter_uneven_sized_batch(batch, allowed_max_frames)

-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = next(model.parameters()).device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
     feature = feature.transpose(1, 2)  # (N, C, T)
-    # feature_lens = supervisions["num_frames"].to(device)

-    batch_idx_train = params.batch_idx_train
     supervisions = batch["supervisions"]
     texts = batch["supervisions"]["text"]
     # remove spaces in texts
     texts = [normalize_text_alimeeting(text) for text in texts]

-    # messages = [
-    #     {"role": "system", "content": "You are a helpful assistant."},
-    #     {"role": "user", "content": prompt}
-    # ]
+    messages = []
+    for i, text in enumerate(texts):
+        message = [
+            {"role": "system", "content": "你是一个能处理音频的助手。"},
+            {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
+            {"role": "assistant", "content": text},
+        ]
+        messages.append(message)
+
     input_ids, attention_mask, target_ids = preprocess(
-        texts, tokenizer, max_len=512
+        messages, tokenizer, max_len=128
     )

-    # decoder_criterion = LabelSmoothingLoss(
-    #     ignore_index=50256, label_smoothing=0.1, reduction="sum"
-    # )
+    target_ids = target_ids.type(torch.LongTensor)
+    input_ids = input_ids.type(torch.LongTensor)

-    # # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|>
-    # ignore_prefix_size = 3
-    # with torch.set_grad_enabled(is_training):
-    #     encoder_out = model.encoder(feature)
-    #     text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
-    #     text_logits = text_logits[:, ignore_prefix_size:, :]
-    #     target_tokens = target_tokens[:, ignore_prefix_size:]
-    #     loss = decoder_criterion(text_logits, target_tokens.to(device))
-
-    # assert loss.requires_grad == is_training
+    with torch.set_grad_enabled(is_training):
+        model_outputs = model(
+            fbank=feature,
+            input_ids=input_ids.to(device),
+            attention_mask=attention_mask.to(device),
+            labels=target_ids.to(device),
+        )
+        loss = model_outputs.loss
+    assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
+        feature_lens = supervisions["num_frames"]
         info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
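The label handling above is compact, so here is a runnable, self-contained illustration of the masking step (the pad id and token ids are made-up values): pad positions are excluded from the loss by overwriting them with IGNORE_TOKEN_ID (-100), and the attention mask is derived from the same comparison.

    import torch
    from transformers.trainer_pt_utils import LabelSmoother

    IGNORE_TOKEN_ID = LabelSmoother.ignore_index  # -100, ignored by the LM loss

    pad_token_id = 151643  # made-up pad id; the real one comes from the Qwen tokenizer
    input_ids = torch.tensor([[pad_token_id, pad_token_id, 10, 11, 12]])  # left padding

    target_ids = input_ids.clone()
    target_ids[target_ids == pad_token_id] = IGNORE_TOKEN_ID
    attention_mask = input_ids.ne(pad_token_id)

    print(target_ids)      # tensor([[-100, -100,   10,   11,   12]])
    print(attention_mask)  # tensor([[False, False,  True,  True,  True]])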
@@ -505,7 +514,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     tokenizer: whisper.tokenizer.Tokenizer,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -540,9 +549,9 @@ def compute_validation_loss(
 def train_one_epoch(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    scheduler: LRSchedulerType,
+    scheduler: torch.optim.lr_scheduler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
@@ -609,18 +618,26 @@ def train_one_epoch(
                 model.save_checkpoint(
                     save_dir=params.exp_dir,
                     tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                    client_state={"sampler": train_dl.sampler.state_dict()},
+                    client_state={},
+                    exclude_frozen_parameters=True
                 )
+
                 if rank == 0:
                     convert_zero_checkpoint_to_fp32_state_dict(
                         params.exp_dir,
                         f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
                         tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                        exclude_frozen_parameters=True,
+                    )
+                    # save sampler state dict into checkpoint
+                    sampler_state_dict = train_dl.sampler.state_dict()
+                    torch.save(
+                        sampler_state_dict,
+                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
                     )
                     os.system(
                         f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                     )
-
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -698,19 +715,32 @@ def run(rank, world_size, args):

     logging.info("About to create model")

-    if 'whisper' in params.speech_encoder_path_or_name:
-        replace_whisper_encoder_forward()
-        # TODO: directly loading from whisper-ft checkpoint
-        # whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-        speech_encoder = whisper.load_model(params.model_name, "cpu").encoder
-        speech_encoder_dim = speech_encoder.dims.n_audio_ctx
+    # if 'whisper' in params.speech_encoder_path_or_name:
+    replace_whisper_encoder_forward()
+    # TODO: directly loading from whisper-ft checkpoint
+    # whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
+    speech_encoder = whisper_model.encoder
+    speech_encoder_dim = whisper_model.dims.n_audio_state
+
+    if params.use_flash_attn:
+        attn_implementation = "flash_attention_2"
+
+    else:
+        attn_implementation = "eager"

     llm = AutoModelForCausalLM.from_pretrained(
         params.llm_path_or_name,
-        attn_implemented="flash_attention_2",
-        device_map="cpu"
+        attn_implementation=attn_implementation,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+    tokenizer.padding_side = 'left'
+    special_tokens_dict = {
+        "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
+    }
+    tokenizer.add_special_tokens(special_tokens_dict)
+    llm.config.pad_token_id = tokenizer.pad_token_id
+    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)

     encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)

@@ -723,6 +753,11 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    logging.info("Trainable parameters (excluding model.eval modules):")
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            logging.info(f"{name}: {param.shape}")
+

     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
     else:
@@ -770,12 +805,14 @@ def run(rank, world_size, args):
     #     sampler_state_dict = checkpoints["sampler"]
     # else:
     #     sampler_state_dict = None
+    sampler_state_dict = None  # TODO: load sampler state dict
     train_dl = data_module.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = multi_dataset.dev_cuts()
+    # valid_cuts = multi_dataset.dev_cuts()
+    valid_cuts = multi_dataset.aishell_dev_cuts()
     valid_dl = data_module.valid_dataloaders(valid_cuts)

     if args.tensorboard and rank == 0:
@@ -818,14 +855,20 @@ def run(rank, world_size, args):
         model.save_checkpoint(
             save_dir=params.exp_dir,
             tag=f"epoch-{params.cur_epoch}",
-            client_state={"sampler": train_dl.sampler.state_dict()},
+            client_state={},
+            exclude_frozen_parameters=True
         )
         if rank == 0:
             convert_zero_checkpoint_to_fp32_state_dict(
                 params.exp_dir,
                 f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
                 tag=f"epoch-{params.cur_epoch}",
+                exclude_frozen_parameters=True,
             )
+            # save sampler state dict into checkpoint
+            sampler_state_dict = train_dl.sampler.state_dict()
+            torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt")
+            os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
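For reference, the tokenizer/LLM wiring added in run() can be exercised on its own. A sketch assuming the diff's default Qwen/Qwen1.5-0.5B-Chat checkpoint (it will be downloaded if not cached locally); the speech placeholder id is stored on llm.config so model.py can look it up without importing train.py:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    DEFAULT_SPEECH_TOKEN = "<speech>"

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
    tokenizer.padding_side = "left"
    tokenizer.add_special_tokens({"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]})

    llm = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-0.5B-Chat",
        attn_implementation="eager",  # or "flash_attention_2" when flash-attn is installed
    )
    llm.config.pad_token_id = tokenizer.pad_token_id
    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)

    print(llm.config.default_speech_token_id)  # id assigned to the new special token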