From 271536248f10de74dc8a06b51ac57c3e44d91b8e Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 11 Jun 2024 09:04:29 +0000
Subject: [PATCH] fix decoding issue and padding to longest

---
 .../ASR_LLM/whisper_llm_zh/decode.py | 28 +++++++----
 .../ASR_LLM/whisper_llm_zh/model.py  | 47 +++++++++----------
 .../ASR_LLM/whisper_llm_zh/train.py  | 30 ++++++++----
 3 files changed, 61 insertions(+), 44 deletions(-)

diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index 745cf5104..603a6e0af 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -196,7 +196,7 @@ def get_parser():
         default=True,
         help="Whether to use flash attention.",
     )
-
+
     add_model_arguments(parser)
 
     return parser
@@ -238,7 +238,7 @@ def decode_one_batch(
     ) -> Dict:
         """Preprocesses the data for supervised fine-tuning."""
         texts = []
-        TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+        TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
         for i, msg in enumerate(messages):
             texts.append(
                 tokenizer.apply_chat_template(
@@ -246,11 +246,16 @@
                     tokenize=True,
                     add_generation_prompt=False,
                     chat_template=TEMPLATE,
-                    padding="max_length",
+                    padding="longest",
                     max_length=max_len,
                     truncation=True,
                 )
             )
+        max_len_texts = max([len(text) for text in texts])
+        if tokenizer.padding_side == 'right':
+            texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
+        else:
+            texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
 
         input_ids = torch.tensor(texts, dtype=torch.int)
 
@@ -481,33 +486,36 @@ def main():
     whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
     speech_encoder = whisper_model.encoder
     speech_encoder_dim = whisper_model.dims.n_audio_state
-
+    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
         # torch_dtype=torch.bfloat16
         torch_dtype=torch.float16
+        tokenizer.padding_side = 'left'
     else:
         attn_implementation = "eager"
         torch_dtype=torch.float16
+        tokenizer.padding_side = 'right'
 
     llm = AutoModelForCausalLM.from_pretrained(
         params.llm_path_or_name,
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
-    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    # tokenizer.padding_side = 'left'
+
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
     tokenizer.add_special_tokens(special_tokens_dict)
 
-    llm.config.pad_token_id = tokenizer.pad_token_id
-    llm.config.bos_token_id = tokenizer.bos_token_id
-    llm.config.eos_token_id = tokenizer.eos_token_id
+    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
 
-    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
+    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
 
     model = SPEECH_LLM(
         speech_encoder,
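
Note on the decode.py hunks above: padding="max_length" is replaced by padding="longest" plus an explicit pad-to-the-longest-sequence step, with the pad side taken from tokenizer.padding_side ('left' when flash attention is enabled, 'right' otherwise). Below is a minimal, self-contained sketch of that padding step only; the token ids, pad_token_id and padding_side values are made up for illustration, the real ones come from the Qwen tokenizer in the recipe.

# Sketch of the pad-to-longest step added in decode.py (and train.py).
# `texts` stands in for the per-sample token-id lists returned by
# tokenizer.apply_chat_template(...); the constants below are assumptions.
texts = [[11, 12, 13], [21, 22], [31]]
pad_token_id = 0        # hypothetical; the patch uses tokenizer.pad_token_id
padding_side = "left"   # 'left' with flash attention, 'right' otherwise

max_len_texts = max(len(t) for t in texts)
if padding_side == "right":
    padded = [t + [pad_token_id] * (max_len_texts - len(t)) for t in texts]
else:
    padded = [[pad_token_id] * (max_len_texts - len(t)) + t for t in texts]

print(padded)  # [[11, 12, 13], [0, 21, 22], [0, 0, 31]]

Left padding is the usual choice for batched generation with a decoder-only LLM, since it keeps every sequence's last real token adjacent to the position where new tokens are produced.
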
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
index 7175fab17..df86f87af 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -1,12 +1,12 @@
 from torch import nn
 import torch
 from transformers.trainer_pt_utils import LabelSmoother
-
+from icefall.dist import get_rank
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 
 class EncoderProjector(nn.Module):
     # https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
-    def __init__(self, encoder_dim, llm_dim, downsample_rate=1):
+    def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
         super().__init__()
         self.downsample_rate = downsample_rate
         self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
@@ -47,7 +47,6 @@ class SPEECH_LLM(nn.Module):
                 param.requires_grad = False
             self.llm.eval()
         self.encoder_projector = encoder_projector
-        self.encoder_outputs_downsample_rate = 4
 
     def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
         num_speechs, speech_len, embed_dim = speech_features.shape
@@ -134,33 +133,39 @@ class SPEECH_LLM(nn.Module):
         labels: torch.LongTensor = None,
     ):
         encoder_outs = self.encoder(fbank)
-        # downsample encoder_outs by 4
-        # encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
 
         speech_features = self.encoder_projector(encoder_outs)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        enable_logging = False
+        rank = get_rank()
 
-        # print("input_ids", input_ids, input_ids.shape)
-        # print("labels", labels, labels.shape)
-        # print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
-        # print("attention_mask_before", attention_mask.shape, attention_mask)
-        # print(2333333333333333333333333333)
+        # log only on rank 0, training using deep
+        if enable_logging and rank == 0:
+            print("input_ids", input_ids, input_ids.shape)
+            print("labels", labels, labels.shape)
+            print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
+            print("attention_mask_before", attention_mask.shape, attention_mask)
+            print(2333333333333333333333333333)
 
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
-        # print("labels", labels, labels.shape)
-        # print("speech_features", speech_features.shape, speech_features)
-        # print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
-        # print("attention_mask", attention_mask.shape, attention_mask)
-        # print("position_ids", position_ids.shape, position_ids)
-        # print("================================================================")
-        # input()
+        if enable_logging and rank == 0:
+            print("speech_features", speech_features.shape, speech_features)
+            print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
+            print("attention_mask", attention_mask.shape, attention_mask)
+            print("position_ids", position_ids.shape, position_ids)
+            print("labels", labels, labels.shape)
+            print("================================================================")
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
         # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
+            if enable_logging and rank == 0:
+                print("preds", preds, preds.shape)
+                print(4555555555555555555555555555555555555555555)
             acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
 
         return model_outputs, acc
@@ -173,9 +178,6 @@ class SPEECH_LLM(nn.Module):
     ):
         encoder_outs = self.encoder(fbank)
 
-        # downsample encoder_outs by 4
-        # encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
-
         speech_features = self.encoder_projector(encoder_outs)
         speech_features = speech_features.to(torch.float16)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
@@ -196,10 +198,7 @@ class SPEECH_LLM(nn.Module):
             eos_token_id=self.llm.config.eos_token_id,
             pad_token_id=self.llm.config.pad_token_id
         )
-        # print(generated_ids, input_ids)
-        # generated_ids = [
-        #     output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
-        # ]
+
         return generated_ids
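
Note on the model.py hunks above: the hard-coded encoder_outputs_downsample_rate = 4 slicing is dropped and downsampling now lives in EncoderProjector, whose first layer is nn.Linear(encoder_dim * downsample_rate, llm_dim). Consistent with that input size (and with the SLAM-LLM projector linked in the comment), the projector stacks every downsample_rate consecutive encoder frames along the feature dimension before projecting, which shrinks the number of speech tokens fed to the LLM by that factor. The sketch below illustrates the frame-stacking idea only; class name, dimensions, and the single-linear layer are illustrative, not the exact icefall implementation.

import torch
import torch.nn as nn


class ProjectorSketch(nn.Module):
    """Stack every `downsample_rate` frames, then project to the LLM width."""

    def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 5):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.proj = nn.Linear(encoder_dim * downsample_rate, llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_frames, encoder_dim)
        b, t, d = x.shape
        t = t - (t % self.downsample_rate)  # drop the ragged tail
        x = x[:, :t, :].reshape(b, t // self.downsample_rate, d * self.downsample_rate)
        return self.proj(x)  # (batch, num_frames // downsample_rate, llm_dim)


feats = torch.randn(2, 101, 1280)  # hypothetical batch of Whisper encoder outputs
print(ProjectorSketch(1280, 3584)(feats).shape)  # torch.Size([2, 20, 3584]); llm_dim chosen arbitrarily
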
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 1f6d4abad..10023ec9a 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -106,7 +106,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-projector-ds-rate",
         type=int,
-        default=4,
+        default=1,
         help="Downsample rate for the encoder projector.",
     )
 
@@ -287,7 +287,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 10000,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )
@@ -408,12 +408,17 @@
                 tokenize=True,
                 chat_template=TEMPLATE,
                 add_generation_prompt=False,
-                padding="max_length",
+                padding="longest", # FIX me change padding to longest
                 max_length=max_len,
                 truncation=True,
             )
         )
-
+    # padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
+    max_len_texts = max([len(text) for text in texts])
+    if tokenizer.padding_side == 'right':
+        texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
+    else:
+        texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
     input_ids = torch.tensor(texts, dtype=torch.int)
     # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
     target_ids = input_ids.clone()
@@ -423,8 +428,7 @@
     mask_prompt = True
     if mask_prompt:
         mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
-        # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
-        # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
+        # then mask all tokens before the first token e.g. 151646 (speech), 151645 , 198 \n
         for i in range(mask_indices[0].size(0)):
             row = mask_indices[0][i]
             col = mask_indices[1][i]
@@ -526,7 +530,7 @@ def compute_loss(
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
-    info["acc"] = acc
+    info["acc"] = acc * info["frames"]  # WAR: to avoid normalization by the number of frames
 
     return loss, info
 
@@ -743,22 +747,24 @@ def run(rank, world_size, args):
     speech_encoder = whisper_model.encoder
     speech_encoder_dim = whisper_model.dims.n_audio_state
 
+    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
         # torch_dtype=torch.bfloat16
         torch_dtype=torch.float16
+        tokenizer.padding_side = 'left'
     else:
         attn_implementation = "eager"
         torch_dtype=torch.float16
+        tokenizer.padding_side = 'right'
 
     llm = AutoModelForCausalLM.from_pretrained(
         params.llm_path_or_name,
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
-    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    # tokenizer.padding_side = 'left'
+
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
 
@@ -766,7 +772,7 @@ def run(rank, world_size, args):
     llm.config.pad_token_id = tokenizer.pad_token_id
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
 
-    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
+    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
 
     model = SPEECH_LLM(
         speech_encoder,
@@ -774,6 +780,10 @@
         encoder_projector,
     )
 
+    if params.pretrained_model_path:
+        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+        model.load_state_dict(checkpoint, strict=False)
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")