Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

Commit d742043e75: refactor decode part
Parent: 71a0a442a6
@@ -60,7 +60,7 @@ from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
 
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from train import DEFAULT_SPEECH_TOKEN
 from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
@@ -70,10 +70,164 @@ from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
-    str2bool,
     write_error_stats,
+    average_checkpoints,
 )
 
+def get_model(params, device):
+    """Load and prepare the speech-to-speech model."""
+    if params.remove_whisper_encoder_input_length_restriction:
+        replace_whisper_encoder_forward()
+
+    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
+    speech_encoder = whisper_model.encoder
+    speech_encoder_dim = whisper_model.dims.n_audio_state
+    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+
+    if params.use_flash_attn:
+        attn_implementation = "flash_attention_2"
+        # torch_dtype=torch.bfloat16 FIX ME
+        torch_dtype = torch.float16
+        tokenizer.padding_side = "left"
+    else:
+        attn_implementation = "eager"
+        torch_dtype = torch.float16
+        tokenizer.padding_side = "right"
+
+    llm = AutoModelForCausalLM.from_pretrained(
+        params.llm_path_or_name,
+        attn_implementation=attn_implementation,
+        torch_dtype=torch_dtype,
+    )
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=[
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "up_proj",
+                "gate_proj",
+                "down_proj",
+            ],
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()
+
+    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
+    tokenizer.add_special_tokens(special_tokens_dict)
+    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
+        DEFAULT_SPEECH_TOKEN
+    )
+
+    encoder_projector = EncoderProjector(
+        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
+    )
+
+    if params.enable_speech_output:
+        # Determine attn_implementation and torch_dtype based on use_flash_attn
+        if params.use_flash_attn:
+            attn_implementation = "flash_attention_2"
+            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
+        else:
+            attn_implementation = "eager"
+            torch_dtype = torch.float16
+
+        # codec_lm = AutoModelForCausalLM.from_pretrained(
+        #     params.llm_path_or_name,
+        #     attn_implementation=attn_implementation,
+        #     torch_dtype=torch_dtype,
+        # )
+        codec_vocab_size = 4096 + 4
+        config = Qwen2Config(
+            vocab_size=codec_vocab_size,
+            hidden_size=1024,
+            num_hidden_layers=12,
+            num_attention_heads=16,
+            num_key_value_heads=16,
+            intermediate_size=2048,
+            max_position_embeddings=4096,
+        )
+        # codec_lm = Qwen2ForCausalLM(config=config)
+        # Pass attn_implementation and torch_dtype to the constructor
+        # Use AutoModelForCausalLM.from_config for more generality
+        codec_lm = AutoModelForCausalLM.from_config(
+            config=config,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype,
+        )
+        # cosyvoice2_token_size = 6561
+        codec_lm.resize_token_embeddings(codec_vocab_size)
+        codec_lm.vocab_size = codec_vocab_size
+        codec_lm.config.pad_token_id = codec_vocab_size - 1
+        codec_lm.config.eos_token_id = codec_vocab_size - 2
+        codec_lm.config.bos_token_id = codec_vocab_size - 3
+        codec_lm.config.mask_token_id = codec_vocab_size - 4
+        # if params.use_lora:
+        #     lora_config = LoraConfig(
+        #         r=64,
+        #         lora_alpha=16,
+        #         target_modules=[
+        #             "q_proj",
+        #             "k_proj",
+        #             "v_proj",
+        #             "o_proj",
+        #             "up_proj",
+        #             "gate_proj",
+        #             "down_proj",
+        #         ],
+        #         lora_dropout=0.05,
+        #         task_type="CAUSAL_LM",
+        #     )
+        #     codec_lm = get_peft_model(codec_lm, lora_config)
+        #     codec_lm.print_trainable_parameters()
+    else:
+        codec_lm = None
+
+    model = SPEECH_LLM(
+        speech_encoder,
+        llm,
+        encoder_projector,
+        codec_lm,
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
+    )
+
+    if params.avg > 1:
+        start = params.epoch - params.avg + 1
+        assert start >= 1, start
+        checkpoint = torch.load(
+            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+        )
+        assert "model" not in checkpoint
+        # deepspeed converted checkpoint only contains model state_dict
+        filenames = [
+            f"{params.exp_dir}/epoch-{epoch}.pt"
+            for epoch in range(start, params.epoch + 1)
+        ]
+        avg_checkpoint = average_checkpoints(filenames)
+        model.load_state_dict(avg_checkpoint, strict=False)
+
+        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+        torch.save(avg_checkpoint, filename)
+    else:
+        checkpoint = torch.load(
+            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+        )
+        model.load_state_dict(checkpoint, strict=False)
+
+    model.to(device)
+    model.eval()
+    return model, tokenizer
+
+
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
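The new get_model() delegates multi-epoch averaging to average_checkpoints(); only its signature appears in the context lines above. Below is a minimal sketch of what such a helper typically does, assuming each epoch-*.pt file is a plain state_dict (consistent with the `assert "model" not in checkpoint` above); it is illustrative, not the function shipped in this file.

from pathlib import Path
from typing import List

import torch


def average_checkpoints_sketch(
    filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
    """Element-wise average of several checkpoint state_dicts (illustrative only)."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)
    for f in filenames[1:]:
        state = torch.load(f, map_location=device)
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n  # integer buffers averaged with floor division
    return avg

For example, with params.epoch=10 and params.avg=3, the filenames built in get_model() are epoch-8.pt, epoch-9.pt, and epoch-10.pt, and the averaged weights are additionally saved as epoch-10-avg-3.pt.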
@@ -171,13 +325,6 @@ def get_parser():
         help="The experiment dir",
     )
 
-    parser.add_argument(
-        "--remove-whisper-encoder-input-length-restriction",
-        type=str2bool,
-        default=True,
-        help="replace whisper encoder forward method to remove input length restriction",
-    )
-
     # parser.add_argument(
     #     "--dataset",
     #     type=str,
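Note that get_model() above still reads params.remove_whisper_encoder_input_length_restriction, so after this removal the flag presumably comes from a parser defined elsewhere (likely shared with the training script); if it did not, re-adding it would look like the removed block, as sketched here with str2bool imported directly since it was dropped from this file's imports:

from icefall.utils import str2bool

# inside get_parser(), after the --exp-dir argument (sketch only)
parser.add_argument(
    "--remove-whisper-encoder-input-length-restriction",
    type=str2bool,
    default=True,
    help="Replace the Whisper encoder forward method to remove the input length restriction",
)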
@@ -321,7 +468,7 @@ def decode_one_batch(
             with open(speech_token_file_name, 'w') as f:
                 # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
                 # torchaudio.save(save_path, speech_output.cpu(), 16000)
-                print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+                # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
                 save_str = " ".join([str(i) for i in generated_speech_output])
                 f.write(f"{cut_id}|{save_str}\n")
 
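The block above writes one utterance per line in the form `cut_id|tok tok tok ...`. A minimal sketch, not part of this commit, of reading such a file back into integer token ids (the helper name is made up):

def load_speech_tokens(path: str) -> dict:
    """Parse lines of the form 'cut_id|tok tok tok ...' into {cut_id: [int, ...]}."""
    tokens = {}
    with open(path) as f:
        for line in f:
            cut_id, token_str = line.rstrip("\n").split("|", 1)
            tokens[cut_id] = [int(t) for t in token_str.split()]
    return tokens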
@@ -509,155 +656,8 @@ def main():
 
     logging.info(f"device: {device}")
 
-    if params.remove_whisper_encoder_input_length_restriction:
-        replace_whisper_encoder_forward()
-
-    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
-    speech_encoder = whisper_model.encoder
-    speech_encoder_dim = whisper_model.dims.n_audio_state
-    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-
-    if params.use_flash_attn:
-        attn_implementation = "flash_attention_2"
-        # torch_dtype=torch.bfloat16 FIX ME
-        torch_dtype = torch.float16
-        tokenizer.padding_side = "left"
-    else:
-        attn_implementation = "eager"
-        torch_dtype = torch.float16
-        tokenizer.padding_side = "right"
-
-    llm = AutoModelForCausalLM.from_pretrained(
-        params.llm_path_or_name,
-        attn_implementation=attn_implementation,
-        torch_dtype=torch_dtype,
-    )
-    if params.use_lora:
-        lora_config = LoraConfig(
-            r=64,
-            lora_alpha=16,
-            target_modules=[
-                "q_proj",
-                "k_proj",
-                "v_proj",
-                "o_proj",
-                "up_proj",
-                "gate_proj",
-                "down_proj",
-            ],
-            task_type="CAUSAL_LM",
-        )
-        llm = get_peft_model(llm, lora_config)
-        llm.print_trainable_parameters()
-
-    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
-    tokenizer.add_special_tokens(special_tokens_dict)
-    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
-    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
-    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
-
-    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
-        DEFAULT_SPEECH_TOKEN
-    )
-
-    encoder_projector = EncoderProjector(
-        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
-    )
-
-    if params.enable_speech_output:
-        # Determine attn_implementation and torch_dtype based on use_flash_attn
-        if params.use_flash_attn:
-            attn_implementation = "flash_attention_2"
-            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
-        else:
-            attn_implementation = "eager"
-            torch_dtype = torch.float16
-
-        # codec_lm = AutoModelForCausalLM.from_pretrained(
-        #     params.llm_path_or_name,
-        #     attn_implementation=attn_implementation,
-        #     torch_dtype=torch_dtype,
-        # )
-        codec_vocab_size = 4096 + 4
-        config = Qwen2Config(
-            vocab_size=codec_vocab_size,
-            hidden_size=1024,
-            num_hidden_layers=12,
-            num_attention_heads=16,
-            num_key_value_heads=16,
-            intermediate_size=2048,
-            max_position_embeddings=4096,
-        )
-        # codec_lm = Qwen2ForCausalLM(config=config)
-        # Pass attn_implementation and torch_dtype to the constructor
-        # Use AutoModelForCausalLM.from_config for more generality
-        codec_lm = AutoModelForCausalLM.from_config(
-            config=config,
-            attn_implementation=attn_implementation,
-            torch_dtype=torch_dtype,
-        )
-        # cosyvoice2_token_size = 6561
-        codec_lm.resize_token_embeddings(codec_vocab_size)
-        codec_lm.vocab_size = codec_vocab_size
-        codec_lm.config.pad_token_id = codec_vocab_size - 1
-        codec_lm.config.eos_token_id = codec_vocab_size - 2
-        codec_lm.config.bos_token_id = codec_vocab_size - 3
-        codec_lm.config.mask_token_id = codec_vocab_size - 4
-        # if params.use_lora:
-        #     lora_config = LoraConfig(
-        #         r=64,
-        #         lora_alpha=16,
-        #         target_modules=[
-        #             "q_proj",
-        #             "k_proj",
-        #             "v_proj",
-        #             "o_proj",
-        #             "up_proj",
-        #             "gate_proj",
-        #             "down_proj",
-        #         ],
-        #         lora_dropout=0.05,
-        #         task_type="CAUSAL_LM",
-        #     )
-        #     codec_lm = get_peft_model(codec_lm, lora_config)
-        #     codec_lm.print_trainable_parameters()
-    else:
-        codec_lm = None
-
-    model = SPEECH_LLM(
-        speech_encoder,
-        llm,
-        encoder_projector,
-        codec_lm,
-        codec_lm_padding_side="left" if params.use_flash_attn else "right",
-    )
-
-    if params.avg > 1:
-        start = params.epoch - params.avg + 1
-        assert start >= 1, start
-        checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-        )
-        assert "model" not in checkpoint
-        # deepspeed converted checkpoint only contains model state_dict
-        filenames = [
-            f"{params.exp_dir}/epoch-{epoch}.pt"
-            for epoch in range(start, params.epoch + 1)
-        ]
-        avg_checkpoint = average_checkpoints(filenames)
-        model.load_state_dict(avg_checkpoint, strict=False)
-
-        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
-        torch.save(avg_checkpoint, filename)
-    else:
-        checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-        )
-        model.load_state_dict(checkpoint, strict=False)
-
-    model.to(device)
-    model.eval()
+    model, tokenizer = get_model(params, device)
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
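With this refactor, main() only needs to build params and call get_model(); checkpoint averaging, optional LoRA, device placement, and eval() now happen inside it. A minimal usage sketch under that assumption, relying on names already defined in this file (how params is actually assembled in main() is outside this hunk):

def run_decode_sketch():
    parser = get_parser()
    args = parser.parse_args()
    params = AttributeDict(vars(args))  # assumption: params mirrors the parsed CLI args
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

    # get_model() returns the model already on `device` and in eval mode,
    # with checkpoints averaged when params.avg > 1.
    model, tokenizer = get_model(params, device)
    # ... run decode_one_batch(...) over the test sets with model and tokenizer ...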