fix decoding issue and padding to longest

Commit: 271536248f
Parent: eb2c255e1e
@@ -196,7 +196,7 @@ def get_parser():
        default=True,
        help="Whether to use flash attention.",
    )

    add_model_arguments(parser)

    return parser
@@ -238,7 +238,7 @@ def decode_one_batch(
) -> Dict:
    """Preprocesses the data for supervised fine-tuning."""
    texts = []
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    for i, msg in enumerate(messages):
        texts.append(
            tokenizer.apply_chat_template(
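The template change above drops the trailing '<|im_end|>' after the final (assistant) turn, so at decode time the prompt ends right after the assistant header and the model generates the transcript and closes the turn itself. A minimal sketch of what the two template variants render, assuming a messages list that ends with an empty assistant turn (plain string building here stands in for tokenizer.apply_chat_template, not the recipe's code):

# Hypothetical illustration of the rendered prompt.
def render(messages, close_last_turn):
    out = []
    for i, m in enumerate(messages):
        out.append("<|im_start|>" + m["role"] + "\n" + m["content"])
        if i != len(messages) - 1:
            out.append("<|im_end|>\n")
        elif close_last_turn:
            out.append("<|im_end|>")
    return "".join(out)

messages = [{"role": "user", "content": "Transcribe the speech."},
            {"role": "assistant", "content": ""}]
print(render(messages, close_last_turn=True))   # old TEMPLATE: prompt already ends with <|im_end|>
print(render(messages, close_last_turn=False))  # new TEMPLATE: generation continues after "assistant\n"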
@@ -246,11 +246,16 @@ def decode_one_batch(
                tokenize=True,
                add_generation_prompt=False,
                chat_template=TEMPLATE,
                padding="max_length",
                padding="longest",
                max_length=max_len,
                truncation=True,
            )
        )
    max_len_texts = max([len(text) for text in texts])
    if tokenizer.padding_side == 'right':
        texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
    else:
        texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]

    input_ids = torch.tensor(texts, dtype=torch.int)
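With padding="longest" the tokenizer only pads within each apply_chat_template call, so the hunk also pads the whole batch to its longest sequence by hand. A self-contained sketch of that step, assuming texts is a list of token-id lists and the tokenizer has a valid pad_token_id (the attention-mask derivation is an assumption, not shown in this hunk):

import torch

def pad_batch(texts, pad_token_id, padding_side="left"):
    # Pad every token-id list to the longest one in the batch.
    max_len = max(len(t) for t in texts)
    if padding_side == "right":
        padded = [t + [pad_token_id] * (max_len - len(t)) for t in texts]
    else:
        # Left padding keeps the last real token at the end of each row,
        # which is what batched generation with a decoder-only LLM expects.
        padded = [[pad_token_id] * (max_len - len(t)) + t for t in texts]
    input_ids = torch.tensor(padded, dtype=torch.long)
    attention_mask = input_ids.ne(pad_token_id)  # assumes pads never occur as real tokens
    return input_ids, attention_mask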
@@ -481,33 +486,36 @@ def main():
    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
    speech_encoder = whisper_model.encoder
    speech_encoder_dim = whisper_model.dims.n_audio_state

    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)

    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16
        torch_dtype=torch.float16
        tokenizer.padding_side = 'left'

    else:
        attn_implementation = "eager"
        torch_dtype=torch.float16
        tokenizer.padding_side = 'right'

    llm = AutoModelForCausalLM.from_pretrained(
        params.llm_path_or_name,
        attn_implementation=attn_implementation,
        torch_dtype=torch_dtype,
    )
    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
    # tokenizer.padding_side = 'left'

    special_tokens_dict = {
        "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
    }
    tokenizer.add_special_tokens(special_tokens_dict)
    llm.config.pad_token_id = tokenizer.pad_token_id
    llm.config.bos_token_id = tokenizer.bos_token_id
    llm.config.eos_token_id = tokenizer.eos_token_id
    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)

    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)

    model = SPEECH_LLM(
        speech_encoder,
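The setup above loads the tokenizer before the LLM so that padding_side can follow the attention backend: left padding when flash attention is used for batched decoding, right padding otherwise. A small self-contained illustration of why left padding matters for batched generation with a decoder-only model (the toy checkpoint and prompts are placeholders, not the recipe's):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token    # gpt2 defines no pad token by default
tok.padding_side = "left"        # with right padding, pads would sit between prompt and generation
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tok(["a short prompt", "a somewhat longer prompt"], return_tensors="pt", padding=True)
out = model.generate(**batch, max_new_tokens=8, pad_token_id=tok.pad_token_id)
print(tok.batch_decode(out, skip_special_tokens=True))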
@@ -1,12 +1,12 @@
from torch import nn
import torch
from transformers.trainer_pt_utils import LabelSmoother

from icefall.dist import get_rank
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

class EncoderProjector(nn.Module):
    # https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
    def __init__(self, encoder_dim, llm_dim, downsample_rate=1):
    def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
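The projector now takes its downsample rate as a constructor argument and sizes its first linear layer to encoder_dim * downsample_rate, i.e. it stacks consecutive encoder frames before projecting them into the LLM embedding space. A sketch of a full module consistent with that constructor, assuming the forward pass mirrors the SLAM-LLM projector linked in the comment (the layers beyond linear1 are assumptions):

import torch
from torch import nn

class EncoderProjectorSketch(nn.Module):
    def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x):
        # x: (batch, time, encoder_dim)
        batch, time, dim = x.size()
        time = time - time % self.downsample_rate  # drop frames that do not fill a group
        x = x[:, :time].reshape(batch, time // self.downsample_rate, dim * self.downsample_rate)
        return self.linear2(self.relu(self.linear1(x)))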
@@ -47,7 +47,6 @@ class SPEECH_LLM(nn.Module):
                param.requires_grad = False
            self.llm.eval()
        self.encoder_projector = encoder_projector
        self.encoder_outputs_downsample_rate = 4

    def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
        num_speechs, speech_len, embed_dim = speech_features.shape
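_merge_input_ids_with_speech_features splices the projected speech frames into the text embedding sequence at the position of the special speech token and rebuilds attention_mask, labels and position_ids to match the longer sequence. A deliberately simplified, batch-size-1 sketch of the core idea (the real method, not shown here, handles full batches and the bookkeeping tensors):

import torch

def merge_single(speech_feats, text_embeds, input_ids, speech_token_id):
    # speech_feats: (T_speech, D); text_embeds: (T_text, D); input_ids: (T_text,)
    # Assumes exactly one speech placeholder token in the prompt.
    pos = (input_ids == speech_token_id).nonzero(as_tuple=True)[0].item()
    # Replace the single placeholder with all speech frames.
    return torch.cat([text_embeds[:pos], speech_feats, text_embeds[pos + 1:]], dim=0)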
@@ -134,33 +133,39 @@ class SPEECH_LLM(nn.Module):
        labels: torch.LongTensor = None,
    ):
        encoder_outs = self.encoder(fbank)
        # downsample encoder_outs by 4
        # encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]

        speech_features = self.encoder_projector(encoder_outs)

        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        enable_logging = False
        rank = get_rank()

        # print("input_ids", input_ids, input_ids.shape)
        # print("labels", labels, labels.shape)
        # print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
        # print("attention_mask_before", attention_mask.shape, attention_mask)
        # print(2333333333333333333333333333)
        # log only on rank 0, training using deep
        if enable_logging and rank == 0:
            print("input_ids", input_ids, input_ids.shape)
            print("labels", labels, labels.shape)
            print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
            print("attention_mask_before", attention_mask.shape, attention_mask)
            print(2333333333333333333333333333)
        inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
        )
        # print("labels", labels, labels.shape)
        # print("speech_features", speech_features.shape, speech_features)
        # print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
        # print("attention_mask", attention_mask.shape, attention_mask)
        # print("position_ids", position_ids.shape, position_ids)
        # print("================================================================")
        # input()
        if enable_logging and rank == 0:
            print("speech_features", speech_features.shape, speech_features)
            print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
            print("attention_mask", attention_mask.shape, attention_mask)
            print("position_ids", position_ids.shape, position_ids)
            print("labels", labels, labels.shape)
            print("================================================================")

        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
        # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            if enable_logging and rank == 0:
                print("preds", preds, preds.shape)
                print(4555555555555555555555555555555555555555555)
            acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
        return model_outputs, acc
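The forward pass now returns a token-level accuracy alongside the loss, computed on predictions shifted by one position relative to the labels (the logit at step t predicts the token at t+1), with prompt and padding positions excluded via IGNORE_TOKEN_ID. compute_accuracy itself is not part of this hunk; a sketch consistent with how it is called:

import torch

def compute_accuracy_sketch(preds, labels, ignore_label):
    # preds, labels: (batch, seq_len) token ids; positions equal to
    # ignore_label do not count towards the accuracy.
    mask = labels.ne(ignore_label)
    correct = (preds == labels) & mask
    return correct.sum().item() / max(mask.sum().item(), 1)

# Call pattern mirroring the code above:
# acc = compute_accuracy_sketch(preds[:, :-1], labels[:, 1:], IGNORE_TOKEN_ID)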
@@ -173,9 +178,6 @@ class SPEECH_LLM(nn.Module):
    ):

        encoder_outs = self.encoder(fbank)
        # downsample encoder_outs by 4
        # encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]

        speech_features = self.encoder_projector(encoder_outs)
        speech_features = speech_features.to(torch.float16)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
@@ -196,10 +198,7 @@ class SPEECH_LLM(nn.Module):
            eos_token_id=self.llm.config.eos_token_id,
            pad_token_id=self.llm.config.pad_token_id
        )
        # print(generated_ids, input_ids)
        # generated_ids = [
        #     output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
        # ]

        return generated_ids
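Turning the ids returned by generate() into text is a plain batch_decode; a short sketch of the decode-side post-processing, hedged because the exact call site is not in this hunk, and whether the prompt must first be sliced off depends on whether generate() was given input_ids or only inputs_embeds:

# hyps: list of hypothesis strings, one per utterance in the batch.
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)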
@@ -106,7 +106,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--encoder-projector-ds-rate",
        type=int,
        default=4,
        default=1,
        help="Downsample rate for the encoder projector.",
    )
@@ -287,7 +287,7 @@ def get_params() -> AttributeDict:
            "batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
            "valid_interval": 10000,
            "valid_interval": 5000,
            "env_info": get_env_info(),
        }
    )
@@ -408,12 +408,17 @@ def compute_loss(
                tokenize=True,
                chat_template=TEMPLATE,
                add_generation_prompt=False,
                padding="max_length",
                padding="longest", # FIX me change padding to longest
                max_length=max_len,
                truncation=True,
            )
        )

    # padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
    max_len_texts = max([len(text) for text in texts])
    if tokenizer.padding_side == 'right':
        texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
    else:
        texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
    input_ids = torch.tensor(texts, dtype=torch.int)
    # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
    target_ids = input_ids.clone()
@@ -423,8 +428,7 @@ def compute_loss(
    mask_prompt = True
    if mask_prompt:
        mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
        # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
        # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
        # then mask all tokens before the first token e.g. 151646 (speech), 151645 <assistant>, 198 \n
        for i in range(mask_indices[0].size(0)):
            row = mask_indices[0][i]
            col = mask_indices[1][i]
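The block above locates the "assistant" role token in each row and masks everything up to (and, per the comments, a couple of tokens past) that position with IGNORE_TOKEN_ID, so the loss is computed only on the transcript tokens. A sketch of the idea; the exact offset past the role token is not shown in this hunk and depends on the chat template, so the default offset below is an assumption:

import torch

def mask_prompt_tokens(target_ids, input_ids, assistant_token_id, ignore_id, offset=2):
    rows, cols = torch.where(input_ids == assistant_token_id)
    for row, col in zip(rows.tolist(), cols.tolist()):
        # Ignore the speech/user prompt, the assistant role token and the
        # following newline token; only the answer contributes to the loss.
        target_ids[row, : col + offset] = ignore_id
    return target_ids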
@@ -526,7 +530,7 @@ def compute_loss(

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
    info["acc"] = acc
    info["acc"] = acc * info["frames"]  # WAR: to avoid normalization by the number of frames

    return loss, info
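Scaling the accuracy by info["frames"] makes it aggregate correctly: if the metrics tracker later sums values across batches and divides by the summed frame count, storing acc * frames yields a frame-weighted average rather than an accuracy that gets wrongly divided by the frame count. A one-line sketch of that aggregation, assuming a hypothetical list of per-batch info dicts:

# Frame-weighted average accuracy over many batches (illustrative only).
avg_acc = sum(info["acc"] for info in batch_infos) / sum(info["frames"] for info in batch_infos)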
@@ -743,22 +747,24 @@ def run(rank, world_size, args):
    speech_encoder = whisper_model.encoder
    speech_encoder_dim = whisper_model.dims.n_audio_state

    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16
        torch_dtype=torch.float16
        tokenizer.padding_side = 'left'

    else:
        attn_implementation = "eager"
        torch_dtype=torch.float16
        tokenizer.padding_side = 'right'

    llm = AutoModelForCausalLM.from_pretrained(
        params.llm_path_or_name,
        attn_implementation=attn_implementation,
        torch_dtype=torch_dtype,
    )
    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
    # tokenizer.padding_side = 'left'

    special_tokens_dict = {
        "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
    }
@@ -766,7 +772,7 @@ def run(rank, world_size, args):
    llm.config.pad_token_id = tokenizer.pad_token_id
    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)

    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)

    model = SPEECH_LLM(
        speech_encoder,
@@ -774,6 +780,10 @@ def run(rank, world_size, args):
        encoder_projector,
    )

    if params.pretrained_model_path:
        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
        model.load_state_dict(checkpoint, strict=False)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
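The new pretrained_model_path branch lets fine-tuning start from an earlier SPEECH_LLM checkpoint; strict=False tolerates keys that exist only on one side (for example frozen submodules that were not saved). A sketch that also reports what did not match, with hypothetical variable names:

import logging
import torch

checkpoint = torch.load(pretrained_model_path, map_location="cpu")  # pretrained_model_path: placeholder
missing, unexpected = model.load_state_dict(checkpoint, strict=False)
logging.info(f"missing keys: {missing}")
logging.info(f"unexpected keys: {unexpected}")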