Repository: https://github.com/k2-fsa/icefall.git
fix decoding issue and padding to longest
Commit: 271536248f
Parent: eb2c255e1e
@@ -196,7 +196,7 @@ def get_parser():
         default=True,
         help="Whether to use flash attention.",
     )
 
     add_model_arguments(parser)
     return parser
 
@@ -238,7 +238,7 @@ def decode_one_batch(
 ) -> Dict:
     """Preprocesses the data for supervised fine-tuning."""
     texts = []
-    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
     for i, msg in enumerate(messages):
         texts.append(
             tokenizer.apply_chat_template(
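Note on the template change: the old decode-time template closed the last message with `<|im_end|>`, so the prompt already looked finished to the model; the new one leaves the last message open, which appears to be the decoding fix named in the commit message. A minimal sketch of the difference, rendered directly with jinja2 (which is what `apply_chat_template` uses under the hood); the message content is invented for illustration:

    import jinja2

    TEMPLATE = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
        "{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}"
        "{% endfor %}"
    )
    messages = [{"role": "user", "content": "Transcribe the audio."}]
    print(jinja2.Template(TEMPLATE).render(messages=messages))
    # <|im_start|>user
    # Transcribe the audio.
    # (no trailing <|im_end|>, so generation can continue from here)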
@@ -246,11 +246,16 @@ def decode_one_batch(
                 tokenize=True,
                 add_generation_prompt=False,
                 chat_template=TEMPLATE,
-                padding="max_length",
+                padding="longest",
                 max_length=max_len,
                 truncation=True,
             )
         )
+    max_len_texts = max([len(text) for text in texts])
+    if tokenizer.padding_side == 'right':
+        texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
+    else:
+        texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
 
     input_ids = torch.tensor(texts, dtype=torch.int)
 
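The added pad-to-batch-max block is needed because `apply_chat_template` is called one sample at a time, so `padding="longest"` only pads within each call and the resulting lists can still differ in length across the batch. A self-contained sketch of the same logic (the pad id is a placeholder, not the recipe's actual value):

    pad_id = 0  # placeholder for tokenizer.pad_token_id
    texts = [[11, 12, 13, 14], [21, 22]]  # per-sample token ids
    max_len_texts = max(len(t) for t in texts)
    # padding_side == 'right': pad after the tokens
    right = [t + [pad_id] * (max_len_texts - len(t)) for t in texts]
    # padding_side == 'left': pad before the tokens
    left = [[pad_id] * (max_len_texts - len(t)) + t for t in texts]
    assert right == [[11, 12, 13, 14], [21, 22, 0, 0]]
    assert left == [[11, 12, 13, 14], [0, 0, 21, 22]]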
@@ -481,33 +486,36 @@ def main():
     whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
     speech_encoder = whisper_model.encoder
     speech_encoder_dim = whisper_model.dims.n_audio_state
+    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
 
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
         # torch_dtype=torch.bfloat16
         torch_dtype=torch.float16
+        tokenizer.padding_side = 'left'
 
     else:
         attn_implementation = "eager"
         torch_dtype=torch.float16
+        tokenizer.padding_side = 'right'
 
     llm = AutoModelForCausalLM.from_pretrained(
         params.llm_path_or_name,
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
-    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    # tokenizer.padding_side = 'left'
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
     tokenizer.add_special_tokens(special_tokens_dict)
-    llm.config.pad_token_id = tokenizer.pad_token_id
-    llm.config.bos_token_id = tokenizer.bos_token_id
-    llm.config.eos_token_id = tokenizer.eos_token_id
+    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
 
-    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
+    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
 
     model = SPEECH_LLM(
         speech_encoder,
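Pinning the pad/bos/eos ids via `convert_tokens_to_ids` makes decoding independent of whatever the loaded tokenizer happens to have configured; the token strings are the Qwen-style chat markers. A quick check, assuming a Qwen2-style checkpoint (the model name below is an assumption, not from the diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # assumed checkpoint
    for tok in ("<|endoftext|>", "<|im_start|>", "<|im_end|>"):
        # convert_tokens_to_ids maps a token string to its vocabulary id
        print(tok, "->", tokenizer.convert_tokens_to_ids(tok))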
@@ -1,12 +1,12 @@
 from torch import nn
 import torch
 from transformers.trainer_pt_utils import LabelSmoother
+from icefall.dist import get_rank
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 
 class EncoderProjector(nn.Module):
     # https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
-    def __init__(self, encoder_dim, llm_dim, downsample_rate=1):
+    def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
         super().__init__()
         self.downsample_rate = downsample_rate
         self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
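The new default matters because `linear1` consumes `encoder_dim * downsample_rate` features: the projector stacks that many consecutive encoder frames before projecting to the LLM width, in the style of the SLAM-LLM projector cited above. A minimal sketch of the stack-and-project idea; everything beyond `linear1`, and the dimensions used, are assumptions:

    import torch
    from torch import nn

    class FrameStackProjector(nn.Module):
        """Stack k consecutive frames, then project to the LLM width."""

        def __init__(self, encoder_dim: int, llm_dim: int, k: int = 5):
            super().__init__()
            self.k = k
            self.linear1 = nn.Linear(encoder_dim * k, llm_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b, t, d = x.shape
            t = t - (t % self.k)  # drop the ragged tail so T divides by k
            x = x[:, :t, :].reshape(b, t // self.k, d * self.k)  # stack k frames
            return self.linear1(x)  # (B, T // k, llm_dim)

    feats = torch.randn(2, 103, 1280)  # e.g. a whisper encoder output
    print(FrameStackProjector(1280, 3584)(feats).shape)  # torch.Size([2, 20, 3584])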
@@ -47,7 +47,6 @@ class SPEECH_LLM(nn.Module):
             param.requires_grad = False
         self.llm.eval()
         self.encoder_projector = encoder_projector
-        self.encoder_outputs_downsample_rate = 4
 
     def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
         num_speechs, speech_len, embed_dim = speech_features.shape
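Dropping `encoder_outputs_downsample_rate` retires the old stride-based downsampling (`encoder_outs[:, ::4]`, visible as a removed comment in the hunks below) in favor of the projector's frame stacking, which keeps the information from every frame instead of discarding three out of four. The difference in one line each (shapes are illustrative):

    import torch

    x = torch.randn(2, 100, 1280)
    strided = x[:, ::4]                   # (2, 25, 1280): keeps 1 frame in 4
    stacked = x.reshape(2, 25, 1280 * 4)  # (2, 25, 5120): keeps all frames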
@@ -134,33 +133,39 @@ class SPEECH_LLM(nn.Module):
         labels: torch.LongTensor = None,
     ):
         encoder_outs = self.encoder(fbank)
-        # downsample encoder_outs by 4
-        # encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
 
         speech_features = self.encoder_projector(encoder_outs)
 
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
 
-        # print("input_ids", input_ids, input_ids.shape)
-        # print("labels", labels, labels.shape)
-        # print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
-        # print("attention_mask_before", attention_mask.shape, attention_mask)
-        # print(2333333333333333333333333333)
+        enable_logging = False
+        rank = get_rank()
+
+        # log only on rank 0; training uses DeepSpeed
+        if enable_logging and rank == 0:
+            print("input_ids", input_ids, input_ids.shape)
+            print("labels", labels, labels.shape)
+            print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
+            print("attention_mask_before", attention_mask.shape, attention_mask)
+            print(2333333333333333333333333333)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
-        # print("labels", labels, labels.shape)
-        # print("speech_features", speech_features.shape, speech_features)
-        # print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
-        # print("attention_mask", attention_mask.shape, attention_mask)
-        # print("position_ids", position_ids.shape, position_ids)
-        # print("================================================================")
-        # input()
+        if enable_logging and rank == 0:
+            print("speech_features", speech_features.shape, speech_features)
+            print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
+            print("attention_mask", attention_mask.shape, attention_mask)
+            print("position_ids", position_ids.shape, position_ids)
+            print("labels", labels, labels.shape)
+            print("================================================================")
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
         # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
+            if enable_logging and rank == 0:
+                print("preds", preds, preds.shape)
+                print(4555555555555555555555555555555555555555555)
         acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
         return model_outputs, acc
 
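The debug prints now sit behind `enable_logging and rank == 0`, so multi-GPU runs no longer emit one copy per worker. A minimal stand-in for `icefall.dist.get_rank` built on torch.distributed (the real helper may also consult environment variables):

    import torch.distributed as dist

    def get_rank() -> int:
        # rank of this process in the default group; 0 when not distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    enable_logging = False  # flip to True for a one-off debugging run
    if enable_logging and get_rank() == 0:
        print("debug output appears once, not once per GPU")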
@@ -173,9 +178,6 @@ class SPEECH_LLM(nn.Module):
     ):
 
         encoder_outs = self.encoder(fbank)
-        # downsample encoder_outs by 4
-        # encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
-
         speech_features = self.encoder_projector(encoder_outs)
         speech_features = speech_features.to(torch.float16)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
@@ -196,10 +198,7 @@ class SPEECH_LLM(nn.Module):
             eos_token_id=self.llm.config.eos_token_id,
             pad_token_id=self.llm.config.pad_token_id
         )
-        # print(generated_ids, input_ids)
-        # generated_ids = [
-        #     output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
-        # ]
         return generated_ids
 
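The removed slicing is only needed when `generate()` echoes the prompt; since this call is driven by `inputs_embeds` rather than `input_ids`, transformers returns just the newly generated tokens for decoder-only models, so trimming would cut real output. For reference, the trim pattern when a prompt is echoed (a hedged sketch; variable names follow the surrounding method):

    # only applies when generate() is called with input_ids and therefore
    # returns prompt + continuation for each sample
    trimmed = [output[len(prompt):] for prompt, output in zip(input_ids, generated_ids)]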
@@ -106,7 +106,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-projector-ds-rate",
         type=int,
-        default=4,
+        default=1,
         help="Downsample rate for the encoder projector.",
     )
 
@@ -287,7 +287,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 10000,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )
@@ -408,12 +408,17 @@ def compute_loss(
                 tokenize=True,
                 chat_template=TEMPLATE,
                 add_generation_prompt=False,
-                padding="max_length",
+                padding="longest",  # FIXME: changed padding to longest
                 max_length=max_len,
                 truncation=True,
             )
         )
+    # pad texts to the same length; texts is a list of lists, padded with tokenizer.pad_token_id
+    max_len_texts = max([len(text) for text in texts])
+    if tokenizer.padding_side == 'right':
+        texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
+    else:
+        texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
     input_ids = torch.tensor(texts, dtype=torch.int)
     # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
     target_ids = input_ids.clone()
@@ -423,8 +428,7 @@ def compute_loss(
         mask_prompt = True
         if mask_prompt:
             mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
-            # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
-            # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
+            # then mask all tokens before the first token e.g. 151646 (speech), 151645 <assistant>, 198 \n
             for i in range(mask_indices[0].size(0)):
                 row = mask_indices[0][i]
                 col = mask_indices[1][i]
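The per-row loop that follows these lines masks everything up to the "assistant" header so the loss covers only the reply. A toy, self-contained version of that masking (the token ids and the `+2` offset are illustrative, not the recipe's exact values):

    import torch

    IGNORE_TOKEN_ID = -100  # LabelSmoother.ignore_index
    ASSISTANT_ID = 9        # illustrative id for the "assistant" token

    input_ids = torch.tensor([[4, 5, 9, 6, 7, 8],
                              [9, 6, 7, 8, 0, 0]])
    target_ids = input_ids.clone()
    rows, cols = torch.where(input_ids == ASSISTANT_ID)
    for r, c in zip(rows.tolist(), cols.tolist()):
        # mask the prompt through "assistant" and the newline after it
        target_ids[r, : c + 2] = IGNORE_TOKEN_ID
    print(target_ids)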
@@ -526,7 +530,7 @@ def compute_loss(
 
         # Note: We use reduction=sum while computing the loss.
         info["loss"] = loss.detach().cpu().item()
-        info["acc"] = acc
+        info["acc"] = acc * info["frames"]  # WAR: to avoid normalization by the number of frames
 
     return loss, info
 
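MetricsTracker-style aggregation divides each summed stat by the summed frame count, so storing `acc * frames` makes the reported number a frames-weighted mean accuracy instead of accuracy divided by frames. A small arithmetic check with made-up numbers:

    # (acc, frames) per batch; values are invented for the illustration
    batches = [(0.50, 100), (0.90, 300)]
    total_frames = sum(f for _, f in batches)
    weighted_acc = sum(a * f for a, f in batches) / total_frames
    print(weighted_acc)  # 0.8 -- the frames-weighted mean, not the plain mean 0.7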
@@ -743,22 +747,24 @@ def run(rank, world_size, args):
     speech_encoder = whisper_model.encoder
     speech_encoder_dim = whisper_model.dims.n_audio_state
 
+    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
         # torch_dtype=torch.bfloat16
         torch_dtype=torch.float16
+        tokenizer.padding_side = 'left'
 
     else:
         attn_implementation = "eager"
         torch_dtype=torch.float16
+        tokenizer.padding_side = 'right'
 
     llm = AutoModelForCausalLM.from_pretrained(
         params.llm_path_or_name,
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
-    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    # tokenizer.padding_side = 'left'
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
@@ -766,7 +772,7 @@ def run(rank, world_size, args):
     llm.config.pad_token_id = tokenizer.pad_token_id
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
 
-    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
+    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
 
     model = SPEECH_LLM(
         speech_encoder,
@@ -774,6 +780,10 @@ def run(rank, world_size, args):
         encoder_projector,
     )
 
+    if params.pretrained_model_path:
+        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+        model.load_state_dict(checkpoint, strict=False)
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 