From 412e926941ce2038c4781d1c0b8871701a43fea6 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Jun 2024 13:25:29 +0000 Subject: [PATCH] fix down sample method --- egs/speech_llm/ASR_LLM/debug.sh | 2 +- .../whisper_llm_zh/ds_config_zero1.json | 39 +++++++++++++- .../ASR_LLM/whisper_llm_zh/model.py | 51 +++++++++++++++---- .../ASR_LLM/whisper_llm_zh/train.py | 14 +++-- 4 files changed, 91 insertions(+), 15 deletions(-) mode change 120000 => 100644 egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json diff --git a/egs/speech_llm/ASR_LLM/debug.sh b/egs/speech_llm/ASR_LLM/debug.sh index 644167275..6cdf03da9 100755 --- a/egs/speech_llm/ASR_LLM/debug.sh +++ b/egs/speech_llm/ASR_LLM/debug.sh @@ -4,7 +4,7 @@ export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm # pip install -r whisper/requirements.txt export CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \ - --max-duration 1 \ + --max-duration 20 \ --exp-dir ./whisper_llm_zh/exp_test \ --speech-encoder-path-or-name tiny \ --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \ diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json deleted file mode 120000 index af7162d6c..000000000 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json +++ /dev/null @@ -1 +0,0 @@ -../../../aishell/ASR/whisper/ds_config_zero1.json \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json new file mode 100644 index 000000000..730937a21 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json @@ -0,0 +1,38 @@ +{ + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 100, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 0.01 + }, + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 1e-4, + "warmup_num_steps": 100 + } + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": 5, + "steps_per_print": 50, + "train_micro_batch_size_per_gpu": 1, + "wall_clock_breakdown": false +} diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py index eec9d7812..3fc7c654b 100644 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py @@ -5,14 +5,25 @@ from transformers.trainer_pt_utils import LabelSmoother IGNORE_TOKEN_ID = LabelSmoother.ignore_index class EncoderProjector(nn.Module): - - def __init__(self, encoder_dim, llm_dim): +# https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py + def __init__(self, encoder_dim, llm_dim, downsample_rate=4): super().__init__() - self.linear1 = nn.Linear(encoder_dim, llm_dim) + self.downsample_rate = downsample_rate + self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim) self.relu = nn.ReLU() self.linear2 = nn.Linear(llm_dim, llm_dim) - def forward(self, x): + def forward(self, x): + + batch_size, seq_len, feat_dim = x.size() + num_frames_to_discard = seq_len % self.downsample_rate + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate) + x = self.linear1(x) x = self.relu(x) x = self.linear2(x) @@ -124,7 +135,7 @@ class SPEECH_LLM(nn.Module): ): encoder_outs = self.encoder(fbank) # downsample encoder_outs by 4 - encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate] + # encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate] speech_features = self.encoder_projector(encoder_outs) @@ -138,8 +149,10 @@ class SPEECH_LLM(nn.Module): #print("speech_features", speech_features.shape) model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids) - - return model_outputs + with torch.no_grad(): + preds = torch.argmax(model_outputs.logits, -1) + acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID) + return model_outputs, acc def decode(self, @@ -151,7 +164,7 @@ class SPEECH_LLM(nn.Module): encoder_outs = self.encoder(fbank) # downsample encoder_outs by 4 - encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate] + # encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate] speech_features = self.encoder_projector(encoder_outs) speech_features = speech_features.to(torch.float16) @@ -177,4 +190,24 @@ class SPEECH_LLM(nn.Module): # generated_ids = [ # output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids) # ] - return generated_ids \ No newline at end of file + return generated_ids + + +def compute_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (LongTensor): Prediction tensors (B, Lmax). + pad_targets (LongTensor): Target label tensors (B, Lmax). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_outputs.masked_select(mask) == pad_targets.masked_select(mask) + ) + denominator = torch.sum(mask) + return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index dab19ca8b..3272ce7f3 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -420,7 +420,11 @@ def compute_loss( # first get the indices of the tokens mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)) # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644 - target_ids[mask_indices[0], :mask_indices[1]+4] = IGNORE_TOKEN_ID + # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID + for i in range(mask_indices[0].size(0)): + row = mask_indices[0][i] + col = mask_indices[1][i] + target_ids[row, :col+4] = IGNORE_TOKEN_ID attention_mask = input_ids.ne(tokenizer.pad_token_id) @@ -496,13 +500,13 @@ def compute_loss( input_ids = input_ids.type(torch.LongTensor) with torch.set_grad_enabled(is_training): - model_outpus = model( + model_outputs, acc = model( fbank=feature, input_ids=input_ids.to(device), attention_mask=attention_mask.to(device), labels=target_ids.to(device), ) - loss = model_outpus.loss + loss = model_outputs.loss assert loss.requires_grad == is_training info = MetricsTracker() @@ -513,6 +517,7 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() + info["acc"] = acc return loss, info @@ -731,7 +736,8 @@ def run(rank, world_size, args): if params.use_flash_attn: attn_implementation = "flash_attention_2" - torch_dtype=torch.bfloat16 + # torch_dtype=torch.bfloat16 + torch_dtype=torch.float16 else: attn_implementation = "eager"