Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

Commit d742043e75: refactor decode part
Parent commit: 71a0a442a6
@@ -60,7 +60,7 @@ from data_module import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model
from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
@@ -70,10 +70,164 @@ from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
    average_checkpoints,
)

def get_model(params, device):
    """Load and prepare the speech-to-speech model."""
    if params.remove_whisper_encoder_input_length_restriction:
        replace_whisper_encoder_forward()

    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
    speech_encoder = whisper_model.encoder
    speech_encoder_dim = whisper_model.dims.n_audio_state
    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)

    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16 FIX ME
        torch_dtype = torch.float16
        tokenizer.padding_side = "left"
    else:
        attn_implementation = "eager"
        torch_dtype = torch.float16
        tokenizer.padding_side = "right"

    llm = AutoModelForCausalLM.from_pretrained(
        params.llm_path_or_name,
        attn_implementation=attn_implementation,
        torch_dtype=torch_dtype,
    )
    if params.use_lora:
        lora_config = LoraConfig(
            r=64,
            lora_alpha=16,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "up_proj",
                "gate_proj",
                "down_proj",
            ],
            task_type="CAUSAL_LM",
        )
        llm = get_peft_model(llm, lora_config)
        llm.print_trainable_parameters()

    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)
    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        DEFAULT_SPEECH_TOKEN
    )

    encoder_projector = EncoderProjector(
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )

    if params.enable_speech_output:
        # Determine attn_implementation and torch_dtype based on use_flash_attn
        if params.use_flash_attn:
            attn_implementation = "flash_attention_2"
            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
        else:
            attn_implementation = "eager"
            torch_dtype = torch.float16

        # codec_lm = AutoModelForCausalLM.from_pretrained(
        #     params.llm_path_or_name,
        #     attn_implementation=attn_implementation,
        #     torch_dtype=torch_dtype,
        # )
        codec_vocab_size = 4096 + 4
        config = Qwen2Config(
            vocab_size=codec_vocab_size,
            hidden_size=1024,
            num_hidden_layers=12,
            num_attention_heads=16,
            num_key_value_heads=16,
            intermediate_size=2048,
            max_position_embeddings=4096,
        )
        # codec_lm = Qwen2ForCausalLM(config=config)
        # Pass attn_implementation and torch_dtype to the constructor.
        # Use AutoModelForCausalLM.from_config for more generality.
        codec_lm = AutoModelForCausalLM.from_config(
            config=config,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
        )
        # cosyvoice2_token_size = 6561
        codec_lm.resize_token_embeddings(codec_vocab_size)
        codec_lm.vocab_size = codec_vocab_size
        codec_lm.config.pad_token_id = codec_vocab_size - 1
        codec_lm.config.eos_token_id = codec_vocab_size - 2
        codec_lm.config.bos_token_id = codec_vocab_size - 3
        codec_lm.config.mask_token_id = codec_vocab_size - 4
        # if params.use_lora:
        #     lora_config = LoraConfig(
        #         r=64,
        #         lora_alpha=16,
        #         target_modules=[
        #             "q_proj",
        #             "k_proj",
        #             "v_proj",
        #             "o_proj",
        #             "up_proj",
        #             "gate_proj",
        #             "down_proj",
        #         ],
        #         lora_dropout=0.05,
        #         task_type="CAUSAL_LM",
        #     )
        #     codec_lm = get_peft_model(codec_lm, lora_config)
        #     codec_lm.print_trainable_parameters()
    else:
        codec_lm = None

    model = SPEECH_LLM(
        speech_encoder,
        llm,
        encoder_projector,
        codec_lm,
        codec_lm_padding_side="left" if params.use_flash_attn else "right",
    )

    if params.avg > 1:
        start = params.epoch - params.avg + 1
        assert start >= 1, start
        checkpoint = torch.load(
            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
        )
        assert "model" not in checkpoint
        # deepspeed converted checkpoint only contains model state_dict
        filenames = [
            f"{params.exp_dir}/epoch-{epoch}.pt"
            for epoch in range(start, params.epoch + 1)
        ]
        avg_checkpoint = average_checkpoints(filenames)
        model.load_state_dict(avg_checkpoint, strict=False)

        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
        torch.save(avg_checkpoint, filename)
    else:
        checkpoint = torch.load(
            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
        )
        model.load_state_dict(checkpoint, strict=False)

    model.to(device)
    model.eval()
    return model, tokenizer


def average_checkpoints(
    filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
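The body of this local average_checkpoints helper falls outside the hunk. As a rough, hedged sketch only (not the icefall implementation), element-wise averaging of bare state_dict checkpoints such as the epoch-*.pt files above typically looks like this:

# Illustrative sketch, assuming each file holds a plain state_dict of tensors.
from pathlib import Path
from typing import List

import torch


def average_checkpoints_sketch(
    filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)
    for f in filenames[1:]:
        state = torch.load(f, map_location=device)
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        # Integer tensors (e.g. counters) are floor-divided; floats are averaged.
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n
    return avg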
@@ -171,13 +325,6 @@ def get_parser():
        help="The experiment dir",
    )

    parser.add_argument(
        "--remove-whisper-encoder-input-length-restriction",
        type=str2bool,
        default=True,
        help="replace whisper encoder forward method to remove input length restriction",
    )

    # parser.add_argument(
    #     "--dataset",
    #     type=str,
@@ -321,7 +468,7 @@ def decode_one_batch(
        with open(speech_token_file_name, 'w') as f:
            # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
            # torchaudio.save(save_path, speech_output.cpu(), 16000)
            print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
            # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
            save_str = " ".join([str(i) for i in generated_speech_output])
            f.write(f"{cut_id}|{save_str}\n")

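Each line written above pairs a cut id with a space-separated list of generated speech (codec) token ids. A small illustrative snippet for reading such a file back; the function name and path argument are placeholders, not part of this commit:

# Hypothetical reader for the "cut_id|tok tok ..." lines produced above.
def load_speech_tokens(path: str) -> dict:
    tokens = {}
    with open(path) as f:
        for line in f:
            cut_id, tok_str = line.rstrip("\n").split("|", maxsplit=1)
            tokens[cut_id] = [int(t) for t in tok_str.split()]
    return tokens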
@@ -509,155 +656,8 @@ def main():

    logging.info(f"device: {device}")

    if params.remove_whisper_encoder_input_length_restriction:
        replace_whisper_encoder_forward()
    model, tokenizer = get_model(params, device)

    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
    speech_encoder = whisper_model.encoder
    speech_encoder_dim = whisper_model.dims.n_audio_state
    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)

    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16 FIX ME
        torch_dtype = torch.float16
        tokenizer.padding_side = "left"
    else:
        attn_implementation = "eager"
        torch_dtype = torch.float16
        tokenizer.padding_side = "right"

    llm = AutoModelForCausalLM.from_pretrained(
        params.llm_path_or_name,
        attn_implementation=attn_implementation,
        torch_dtype=torch_dtype,
    )
    if params.use_lora:
        lora_config = LoraConfig(
            r=64,
            lora_alpha=16,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "up_proj",
                "gate_proj",
                "down_proj",
            ],
            task_type="CAUSAL_LM",
        )
        llm = get_peft_model(llm, lora_config)
        llm.print_trainable_parameters()

    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)
    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        DEFAULT_SPEECH_TOKEN
    )

    encoder_projector = EncoderProjector(
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )

    if params.enable_speech_output:
        # Determine attn_implementation and torch_dtype based on use_flash_attn
        if params.use_flash_attn:
            attn_implementation = "flash_attention_2"
            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
        else:
            attn_implementation = "eager"
            torch_dtype = torch.float16

        # codec_lm = AutoModelForCausalLM.from_pretrained(
        #     params.llm_path_or_name,
        #     attn_implementation=attn_implementation,
        #     torch_dtype=torch_dtype,
        # )
        codec_vocab_size = 4096 + 4
        config = Qwen2Config(
            vocab_size=codec_vocab_size,
            hidden_size=1024,
            num_hidden_layers=12,
            num_attention_heads=16,
            num_key_value_heads=16,
            intermediate_size=2048,
            max_position_embeddings=4096,
        )
        # codec_lm = Qwen2ForCausalLM(config=config)
        # Pass attn_implementation and torch_dtype to the constructor.
        # Use AutoModelForCausalLM.from_config for more generality.
        codec_lm = AutoModelForCausalLM.from_config(
            config=config,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
        )
        # cosyvoice2_token_size = 6561
        codec_lm.resize_token_embeddings(codec_vocab_size)
        codec_lm.vocab_size = codec_vocab_size
        codec_lm.config.pad_token_id = codec_vocab_size - 1
        codec_lm.config.eos_token_id = codec_vocab_size - 2
        codec_lm.config.bos_token_id = codec_vocab_size - 3
        codec_lm.config.mask_token_id = codec_vocab_size - 4
        # if params.use_lora:
        #     lora_config = LoraConfig(
        #         r=64,
        #         lora_alpha=16,
        #         target_modules=[
        #             "q_proj",
        #             "k_proj",
        #             "v_proj",
        #             "o_proj",
        #             "up_proj",
        #             "gate_proj",
        #             "down_proj",
        #         ],
        #         lora_dropout=0.05,
        #         task_type="CAUSAL_LM",
        #     )
        #     codec_lm = get_peft_model(codec_lm, lora_config)
        #     codec_lm.print_trainable_parameters()
    else:
        codec_lm = None

    model = SPEECH_LLM(
        speech_encoder,
        llm,
        encoder_projector,
        codec_lm,
        codec_lm_padding_side="left" if params.use_flash_attn else "right",
    )

    if params.avg > 1:
        start = params.epoch - params.avg + 1
        assert start >= 1, start
        checkpoint = torch.load(
            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
        )
        assert "model" not in checkpoint
        # deepspeed converted checkpoint only contains model state_dict
        filenames = [
            f"{params.exp_dir}/epoch-{epoch}.pt"
            for epoch in range(start, params.epoch + 1)
        ]
        avg_checkpoint = average_checkpoints(filenames)
        model.load_state_dict(avg_checkpoint, strict=False)

        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
        torch.save(avg_checkpoint, filename)
    else:
        checkpoint = torch.load(
            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
        )
        model.load_state_dict(checkpoint, strict=False)

    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

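The assert "model" not in checkpoint above relies on the checkpoints having already been DeepSpeed-converted into a bare state_dict. As an illustrative sketch only (the helper name is a placeholder, not from this commit), a loader that accepts either that format or an icefall-style checkpoint wrapping the weights under a "model" key could look like:

import torch


def load_state_dict_flexible(model, ckpt_path):
    # Hypothetical helper: handle both a bare state_dict (DeepSpeed-converted)
    # and a checkpoint dict that stores the weights under a "model" key.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
    model.load_state_dict(state_dict, strict=False)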