refactor decode part

Yuekai Zhang 2025-04-25 18:31:43 +08:00
parent 71a0a442a6
commit d742043e75


@@ -60,7 +60,7 @@ from data_module import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model
from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
@@ -70,10 +70,164 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
average_checkpoints,
)
def get_model(params, device):
"""Load and prepare the speech-to-speech model."""
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# FIXME: consider torch.bfloat16 instead of float16
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
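# The projector maps Whisper encoder features (speech_encoder_dim) into the
# LLM embedding space (llm.config.hidden_size), downsampling frames by a
# factor of encoder_projector_ds_rate.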
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
if params.enable_speech_output:
# Determine attn_implementation and torch_dtype based on use_flash_attn
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
else:
attn_implementation = "eager"
torch_dtype = torch.float16
# codec_lm = AutoModelForCausalLM.from_pretrained(
# params.llm_path_or_name,
# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 4096 + 4
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=16,
num_key_value_heads=16,
intermediate_size=2048,
max_position_embeddings=4096,
)
# codec_lm = Qwen2ForCausalLM(config=config)
# Pass attn_implementation and torch_dtype to the constructor
# Use AutoModelForCausalLM.from_config for more generality
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype
)
# cosyvoice2_token_size = 6561
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
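# The top four ids of the 4096 + 4 codec vocabulary are reserved for special
# tokens: pad = 4099, eos = 4098, bos = 4097, mask = 4096.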
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
codec_lm.config.mask_token_id = codec_vocab_size - 4
# if params.use_lora:
# lora_config = LoraConfig(
# r=64,
# lora_alpha=16,
# target_modules=[
# "q_proj",
# "k_proj",
# "v_proj",
# "o_proj",
# "up_proj",
# "gate_proj",
# "down_proj",
# ],
# lora_dropout=0.05,
# task_type="CAUSAL_LM",
# )
# codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else:
codec_lm = None
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
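# Checkpoint loading: with --avg > 1, average the (DeepSpeed-converted, plain
# state_dict) checkpoints of epochs [epoch - avg + 1, epoch] and cache the
# averaged copy; otherwise load the single epoch checkpoint directly.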
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict
filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt"
for epoch in range(start, params.epoch + 1)
]
avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
model.load_state_dict(checkpoint, strict=False)
model.to(device)
model.eval()
return model, tokenizer
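For reference, a minimal usage sketch of the new helper, mirroring the call that replaces the inlined construction further down in this diff (the device construction is an assumption; params is the AttributeDict built from the argument parser):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed
model, tokenizer = get_model(params, device)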
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
@@ -171,13 +325,6 @@ def get_parser():
help="The experiment dir",
)
parser.add_argument(
"--remove-whisper-encoder-input-length-restriction",
type=str2bool,
default=True,
help="replace whisper encoder forward method to remove input length restriction",
)
# parser.add_argument(
# "--dataset",
# type=str,
@@ -321,7 +468,7 @@ def decode_one_batch(
with open(speech_token_file_name, 'w') as f:
# save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
#torchaudio.save(save_path, speech_output.cpu(), 16000)
print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
# print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
save_str = " ".join([str(i) for i in generated_speech_output])
f.write(f"{cut_id}|{save_str}\n")
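Each line written above has the form cut_id|tok tok tok .... A hedged sketch of reading such a file back (variable names are illustrative, not part of this commit):

tokens_by_cut = {}
with open(speech_token_file_name) as f:
    for line in f:
        cut_id, tok_str = line.strip().split("|", maxsplit=1)
        tokens_by_cut[cut_id] = [int(t) for t in tok_str.split()]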
@@ -509,155 +656,8 @@ def main():
logging.info(f"device: {device}")
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
model, tokenizer = get_model(params, device)
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16 FIX ME
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
if params.enable_speech_output:
# Determine attn_implementation and torch_dtype based on use_flash_attn
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
else:
attn_implementation = "eager"
torch_dtype = torch.float16
# codec_lm = AutoModelForCausalLM.from_pretrained(
# params.llm_path_or_name,
# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 4096 + 4
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=16,
num_key_value_heads=16,
intermediate_size=2048,
max_position_embeddings=4096,
)
# codec_lm = Qwen2ForCausalLM(config=config)
# Pass attn_implementation and torch_dtype to the constructor
# Use AutoModelForCausalLM.from_config for more generality
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype
)
# cosyvoice2_token_size = 6561
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
codec_lm.config.mask_token_id = codec_vocab_size - 4
# if params.use_lora:
# lora_config = LoraConfig(
# r=64,
# lora_alpha=16,
# target_modules=[
# "q_proj",
# "k_proj",
# "v_proj",
# "o_proj",
# "up_proj",
# "gate_proj",
# "down_proj",
# ],
# lora_dropout=0.05,
# task_type="CAUSAL_LM",
# )
# codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else:
codec_lm = None
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side= "left" if params.use_flash_attn else "right",
)
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict
filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt"
for epoch in range(start, params.epoch + 1)
]
avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
model.load_state_dict(checkpoint, strict=False)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")