This commit is contained in:
Yuekai Zhang 2024-06-04 18:46:16 +08:00
parent e495c9d732
commit b5a906cbbd
4 changed files with 133 additions and 75 deletions

View File

@ -1,15 +1,14 @@
export PYTHONPATH=$PYTHONPATH:/workspace/asr/icefall
export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
# pip install -r whisper/requirements.txt
method=mask_predict
# method=cif_ar_distill_embedding
torchrun --nproc_per_node 8 ./parawhisper/train.py \
--max-duration 200 \
--exp-dir parawhisper/exp_large_v2_${method} \
--model-name large-v2 \
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
--max-duration 80 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name tiny \
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
--manifest-dir data/fbank \
--method $method \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn False

View File

@ -1,7 +1,8 @@
from torch import nn
import torch
from transformers.trainer_pt_utils import LabelSmoother
DEFAULT_SPEECH_TOKEN = -1997 # "<speech>"
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class EncoderProjector(nn.Module):
@ -18,7 +19,6 @@ class EncoderProjector(nn.Module):
return x
class SPEECH_LLM(nn.Module):
# https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/slam_model.py
def __init__(
self,
encoder: nn.Module,
@ -28,8 +28,12 @@ class SPEECH_LLM(nn.Module):
super().__init__()
self.encoder = encoder
for name, param in encoder.named_parameters():
param.requires_grad = False
self.encoder.eval()
self.llm = llm
for name, param in llm.named_parameters():
param.requires_grad = False
self.llm.eval()
self.encoder_projector = encoder_projector
self.encoder_outputs_downsample_rate = 4
@ -39,11 +43,11 @@ class SPEECH_LLM(nn.Module):
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
# 1. Create a mask to know where special speech tokens are
special_speech_token_mask = input_ids == DEFAULT_SPEECH_TOKEN
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length
batch_indices, non_speech_indices = torch.where(input_ids != DEFAULT_SPEECH_TOKEN)
batch_indices, non_speech_indices = torch.where(input_ids != self.llm.config.default_speech_token_id)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged speech-text sequence.
@ -65,7 +69,7 @@ class SPEECH_LLM(nn.Module):
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self., dtype=input_ids.dtype, device=input_ids.device
(batch_size, max_embed_dim), IGNORE_TOKEN_ID, dtype=input_ids.dtype, device=input_ids.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
@ -128,17 +132,17 @@ class SPEECH_LLM(nn.Module):
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = outputs[0]
# outputs = self.llm(
# attention_mask=attention_mask,
# position_ids=position_ids,
# past_key_values=past_key_values,
# inputs_embeds=inputs_embeds,
# use_cache=use_cache,
# output_attentions=output_attentions,
# output_hidden_states=output_hidden_states,
# return_dict=return_dict,
# )
# logits = outputs[0]
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)

View File

@ -256,3 +256,15 @@ class MultiDataset:
)
return aishell_cuts
def aishell_dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
# AISHELL
logging.info("Loading Aishell set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)
return aishell_dev_cuts

View File

@ -39,13 +39,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import k2
import optim
# import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import whisper
from asr_datamodule import AsrDataModule
from model import SPEECH_LLM, EncoderProjector
from model import SPEECH_LLM, EncoderProjector, IGNORE_TOKEN_ID
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
@ -59,7 +59,7 @@ from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor
# from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics
@ -78,13 +78,12 @@ from icefall.utils import (
)
from transformers import AutoModelForCausalLM, AutoTokenizer
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
import transformers
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
DEFAULT_SPEECH_TOKEN = "<speech>"
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
def set_batch_count(model: nn.Module, batch_count: float) -> None:
for module in model.modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
@ -240,6 +239,13 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
@ -272,6 +278,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"allowed_excess_duration_ratio": 0.1,
"subsampling_factor": 2,
"frame_shift_ms": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
@ -357,7 +364,7 @@ def get_params() -> AttributeDict:
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: Union[nn.Module, DDP],
model: nn.Module,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
@ -397,7 +404,6 @@ def compute_loss(
texts.append(
tokenizer.apply_chat_template(
msg,
chat_template=TEMPLATE,
tokenize=True,
add_generation_prompt=False,
padding="max_length",
@ -405,8 +411,9 @@ def compute_loss(
truncation=True,
)
)
# model_inputs = tokenizer([text], return_tensors="pt").to(device)
input_ids = torch.tensor(texts, dtype=torch.int)
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
@ -453,47 +460,49 @@ def compute_loss(
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
# feature_lens = supervisions["num_frames"].to(device)
batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"]
# remove spaces in texts
texts = [normalize_text_alimeeting(text) for text in texts]
# messages = [
# {"role": "system", "content": "You are a helpful assistant."},
# {"role": "user", "content": prompt}
# ]
messages = []
for i, text in enumerate(texts):
message = [
{"role": "system", "content": "你是一个能处理音频的助手。"},
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": text},
]
messages.append(message)
input_ids, attention_mask, target_ids = preprocess(
texts, tokenizer, max_len=512
messages, tokenizer, max_len=128
)
# decoder_criterion = LabelSmoothingLoss(
# ignore_index=50256, label_smoothing=0.1, reduction="sum"
# )
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
# # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|>
# ignore_prefix_size = 3
# with torch.set_grad_enabled(is_training):
# encoder_out = model.encoder(feature)
# text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
# text_logits = text_logits[:, ignore_prefix_size:, :]
# target_tokens = target_tokens[:, ignore_prefix_size:]
# loss = decoder_criterion(text_logits, target_tokens.to(device))
# assert loss.requires_grad == is_training
with torch.set_grad_enabled(is_training):
model_outpus = model(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
)
loss = model_outpus.loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
feature_lens = supervisions["num_frames"]
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
@ -505,7 +514,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
tokenizer: whisper.tokenizer.Tokenizer,
model: Union[nn.Module, DDP],
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
@ -540,9 +549,9 @@ def compute_validation_loss(
def train_one_epoch(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: Union[nn.Module, DDP],
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
scheduler: torch.optim.lr_scheduler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
@ -609,18 +618,26 @@ def train_one_epoch(
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
client_state={"sampler": train_dl.sampler.state_dict()},
client_state={},
exclude_frozen_parameters=True
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
)
os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
@ -698,19 +715,32 @@ def run(rank, world_size, args):
logging.info("About to create model")
if 'whisper' in params.speech_encoder_path_or_name:
# if 'whisper' in params.speech_encoder_path_or_name:
replace_whisper_encoder_forward()
# TODO: directly loading from whisper-ft checkpoint
# whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
speech_encoder = whisper.load_model(params.model_name, "cpu").encoder
speech_encoder_dim = speech_encoder.dims.n_audio_ctx
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
else:
attn_implementation = "eager"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implemented="flash_attention_2",
device_map="cpu"
attn_implementation=attn_implementation,
)
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
tokenizer.padding_side = 'left'
special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
@ -723,6 +753,11 @@ def run(rank, world_size, args):
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
logging.info("Trainable parameters (excluding model.eval modules):")
for name, param in model.named_parameters():
if param.requires_grad:
logging.info(f"{name}: {param.shape}")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
else:
@ -770,12 +805,14 @@ def run(rank, world_size, args):
# sampler_state_dict = checkpoints["sampler"]
# else:
# sampler_state_dict = None
sampler_state_dict = None
# TODO: load sampler state dict
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = multi_dataset.dev_cuts()
# valid_cuts = multi_dataset.dev_cuts()
valid_cuts = multi_dataset.aishell_dev_cuts()
valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0:
@ -818,14 +855,20 @@ def run(rank, world_size, args):
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
client_state={"sampler": train_dl.sampler.state_dict()},
client_state={},
exclude_frozen_parameters=True
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt")
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")