This commit is contained in:
Yuekai Zhang 2024-06-04 18:46:16 +08:00
parent e495c9d732
commit b5a906cbbd
4 changed files with 133 additions and 75 deletions

View File

@ -1,15 +1,14 @@
export PYTHONPATH=$PYTHONPATH:/workspace/asr/icefall export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html # pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
# pip install -r whisper/requirements.txt # pip install -r whisper/requirements.txt
export CUDA_VISIBLE_DEVICES=0,1
method=mask_predict torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
# method=cif_ar_distill_embedding --max-duration 80 \
torchrun --nproc_per_node 8 ./parawhisper/train.py \ --exp-dir ./whisper_llm_zh/exp_test \
--max-duration 200 \ --speech-encoder-path-or-name tiny \
--exp-dir parawhisper/exp_large_v2_${method} \ --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
--model-name large-v2 \
--manifest-dir data/fbank \ --manifest-dir data/fbank \
--method $method \
--deepspeed \ --deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn False

View File

@ -1,7 +1,8 @@
from torch import nn from torch import nn
import torch import torch
from transformers.trainer_pt_utils import LabelSmoother
DEFAULT_SPEECH_TOKEN = -1997 # "<speech>" IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class EncoderProjector(nn.Module): class EncoderProjector(nn.Module):
@ -18,7 +19,6 @@ class EncoderProjector(nn.Module):
return x return x
class SPEECH_LLM(nn.Module): class SPEECH_LLM(nn.Module):
# https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/slam_model.py
def __init__( def __init__(
self, self,
encoder: nn.Module, encoder: nn.Module,
@ -28,8 +28,12 @@ class SPEECH_LLM(nn.Module):
super().__init__() super().__init__()
self.encoder = encoder self.encoder = encoder
for name, param in encoder.named_parameters():
param.requires_grad = False
self.encoder.eval() self.encoder.eval()
self.llm = llm self.llm = llm
for name, param in llm.named_parameters():
param.requires_grad = False
self.llm.eval() self.llm.eval()
self.encoder_projector = encoder_projector self.encoder_projector = encoder_projector
self.encoder_outputs_downsample_rate = 4 self.encoder_outputs_downsample_rate = 4
@ -39,11 +43,11 @@ class SPEECH_LLM(nn.Module):
batch_size, sequence_length = input_ids.shape batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)) left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
# 1. Create a mask to know where special speech tokens are # 1. Create a mask to know where special speech tokens are
special_speech_token_mask = input_ids == DEFAULT_SPEECH_TOKEN special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1) num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
# Compute the maximum embed dimension # Compute the maximum embed dimension
max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length
batch_indices, non_speech_indices = torch.where(input_ids != DEFAULT_SPEECH_TOKEN) batch_indices, non_speech_indices = torch.where(input_ids != self.llm.config.default_speech_token_id)
# 2. Compute the positions where text should be written # 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged speech-text sequence. # Calculate new positions for text tokens in merged speech-text sequence.
@ -65,7 +69,7 @@ class SPEECH_LLM(nn.Module):
) )
if labels is not None: if labels is not None:
final_labels = torch.full( final_labels = torch.full(
(batch_size, max_embed_dim), self., dtype=input_ids.dtype, device=input_ids.device (batch_size, max_embed_dim), IGNORE_TOKEN_ID, dtype=input_ids.dtype, device=input_ids.device
) )
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device. # set the corresponding tensors into their correct target device.
@ -128,17 +132,17 @@ class SPEECH_LLM(nn.Module):
speech_features, inputs_embeds, input_ids, attention_mask, labels speech_features, inputs_embeds, input_ids, attention_mask, labels
) )
outputs = self.language_model( # outputs = self.llm(
attention_mask=attention_mask, # attention_mask=attention_mask,
position_ids=position_ids, # position_ids=position_ids,
past_key_values=past_key_values, # past_key_values=past_key_values,
inputs_embeds=inputs_embeds, # inputs_embeds=inputs_embeds,
use_cache=use_cache, # use_cache=use_cache,
output_attentions=output_attentions, # output_attentions=output_attentions,
output_hidden_states=output_hidden_states, # output_hidden_states=output_hidden_states,
return_dict=return_dict, # return_dict=return_dict,
) # )
logits = outputs[0] # logits = outputs[0]
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids) model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)

View File

@ -255,4 +255,16 @@ class MultiDataset:
self.fbank_dir / "aishell_cuts_train.jsonl.gz" self.fbank_dir / "aishell_cuts_train.jsonl.gz"
) )
return aishell_cuts return aishell_cuts
def aishell_dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
# AISHELL
logging.info("Loading Aishell set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)
return aishell_dev_cuts

View File

@ -39,13 +39,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed import deepspeed
import k2 import k2
import optim # import optim
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import whisper import whisper
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from model import SPEECH_LLM, EncoderProjector from model import SPEECH_LLM, EncoderProjector, IGNORE_TOKEN_ID
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest from lhotse import CutSet, load_manifest
@ -59,7 +59,7 @@ from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor from torch.nn.functional import pad as pad_tensor
# from torch.nn.parallel import DistributedDataParallel as DDP # from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics from icefall import diagnostics
@ -78,13 +78,12 @@ from icefall.utils import (
) )
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] import transformers
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
DEFAULT_SPEECH_TOKEN = "<speech>"
def set_batch_count(model: nn.Module, batch_count: float) -> None:
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for module in model.modules(): for module in model.modules():
if hasattr(module, "batch_count"): if hasattr(module, "batch_count"):
module.batch_count = batch_count module.batch_count = batch_count
@ -240,6 +239,13 @@ def get_parser():
help="Whether to use half precision training.", help="Whether to use half precision training.",
) )
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
parser = deepspeed.add_config_arguments(parser) parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser) add_model_arguments(parser)
@ -272,6 +278,7 @@ def get_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"allowed_excess_duration_ratio": 0.1, "allowed_excess_duration_ratio": 0.1,
"subsampling_factor": 2,
"frame_shift_ms": 10, "frame_shift_ms": 10,
"best_train_loss": float("inf"), "best_train_loss": float("inf"),
"best_valid_loss": float("inf"), "best_valid_loss": float("inf"),
@ -357,7 +364,7 @@ def get_params() -> AttributeDict:
def compute_loss( def compute_loss(
params: AttributeDict, params: AttributeDict,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
model: Union[nn.Module, DDP], model: nn.Module,
batch: dict, batch: dict,
is_training: bool, is_training: bool,
) -> Tuple[Tensor, MetricsTracker]: ) -> Tuple[Tensor, MetricsTracker]:
@ -397,7 +404,6 @@ def compute_loss(
texts.append( texts.append(
tokenizer.apply_chat_template( tokenizer.apply_chat_template(
msg, msg,
chat_template=TEMPLATE,
tokenize=True, tokenize=True,
add_generation_prompt=False, add_generation_prompt=False,
padding="max_length", padding="max_length",
@ -405,8 +411,9 @@ def compute_loss(
truncation=True, truncation=True,
) )
) )
# model_inputs = tokenizer([text], return_tensors="pt").to(device)
input_ids = torch.tensor(texts, dtype=torch.int) input_ids = torch.tensor(texts, dtype=torch.int)
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
target_ids = input_ids.clone() target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id) attention_mask = input_ids.ne(tokenizer.pad_token_id)
@ -453,47 +460,49 @@ def compute_loss(
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames) batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device device = next(model.parameters()).device
feature = batch["inputs"] feature = batch["inputs"]
assert feature.ndim == 3 assert feature.ndim == 3
feature = feature.to(device) feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T) feature = feature.transpose(1, 2) # (N, C, T)
# feature_lens = supervisions["num_frames"].to(device)
batch_idx_train = params.batch_idx_train batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
# remove spaces in texts # remove spaces in texts
texts = [normalize_text_alimeeting(text) for text in texts] texts = [normalize_text_alimeeting(text) for text in texts]
# messages = [ messages = []
# {"role": "system", "content": "You are a helpful assistant."}, for i, text in enumerate(texts):
# {"role": "user", "content": prompt} message = [
# ] {"role": "system", "content": "你是一个能处理音频的助手。"},
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": text},
]
messages.append(message)
input_ids, attention_mask, target_ids = preprocess( input_ids, attention_mask, target_ids = preprocess(
texts, tokenizer, max_len=512 messages, tokenizer, max_len=128
) )
# decoder_criterion = LabelSmoothingLoss( target_ids = target_ids.type(torch.LongTensor)
# ignore_index=50256, label_smoothing=0.1, reduction="sum" input_ids = input_ids.type(torch.LongTensor)
# )
# # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|> with torch.set_grad_enabled(is_training):
# ignore_prefix_size = 3 model_outpus = model(
# with torch.set_grad_enabled(is_training): fbank=feature,
# encoder_out = model.encoder(feature) input_ids=input_ids.to(device),
# text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) attention_mask=attention_mask.to(device),
# text_logits = text_logits[:, ignore_prefix_size:, :] labels=target_ids.to(device),
# target_tokens = target_tokens[:, ignore_prefix_size:] )
# loss = decoder_criterion(text_logits, target_tokens.to(device)) loss = model_outpus.loss
assert loss.requires_grad == is_training
# assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
feature_lens = supervisions["num_frames"]
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
@ -505,7 +514,7 @@ def compute_loss(
def compute_validation_loss( def compute_validation_loss(
params: AttributeDict, params: AttributeDict,
tokenizer: whisper.tokenizer.Tokenizer, tokenizer: whisper.tokenizer.Tokenizer,
model: Union[nn.Module, DDP], model: nn.Module,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
world_size: int = 1, world_size: int = 1,
) -> MetricsTracker: ) -> MetricsTracker:
@ -540,9 +549,9 @@ def compute_validation_loss(
def train_one_epoch( def train_one_epoch(
params: AttributeDict, params: AttributeDict,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
model: Union[nn.Module, DDP], model: nn.Module,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType, scheduler: torch.optim.lr_scheduler,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None, tb_writer: Optional[SummaryWriter] = None,
@ -609,18 +618,26 @@ def train_one_epoch(
model.save_checkpoint( model.save_checkpoint(
save_dir=params.exp_dir, save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
client_state={"sampler": train_dl.sampler.state_dict()}, client_state={},
exclude_frozen_parameters=True
) )
if rank == 0: if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict( convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir, params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
) )
os.system( os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
) )
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss( loss, loss_info = compute_loss(
@ -698,19 +715,32 @@ def run(rank, world_size, args):
logging.info("About to create model") logging.info("About to create model")
if 'whisper' in params.speech_encoder_path_or_name: # if 'whisper' in params.speech_encoder_path_or_name:
replace_whisper_encoder_forward() replace_whisper_encoder_forward()
# TODO: directly loading from whisper-ft checkpoint # TODO: directly loading from whisper-ft checkpoint
# whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt # whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
speech_encoder = whisper.load_model(params.model_name, "cpu").encoder whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder_dim = speech_encoder.dims.n_audio_ctx speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
else:
attn_implementation = "eager"
llm = AutoModelForCausalLM.from_pretrained( llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name, params.llm_path_or_name,
attn_implemented="flash_attention_2", attn_implementation=attn_implementation,
device_map="cpu"
) )
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
tokenizer.padding_side = 'left'
special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size) encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
@ -723,6 +753,11 @@ def run(rank, world_size, args):
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
logging.info("Trainable parameters (excluding model.eval modules):")
for name, param in model.named_parameters():
if param.requires_grad:
logging.info(f"{name}: {param.shape}")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", rank) device = torch.device("cuda", rank)
else: else:
@ -770,12 +805,14 @@ def run(rank, world_size, args):
# sampler_state_dict = checkpoints["sampler"] # sampler_state_dict = checkpoints["sampler"]
# else: # else:
# sampler_state_dict = None # sampler_state_dict = None
sampler_state_dict = None
# TODO: load sampler state dict # TODO: load sampler state dict
train_dl = data_module.train_dataloaders( train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict train_cuts, sampler_state_dict=sampler_state_dict
) )
valid_cuts = multi_dataset.dev_cuts() # valid_cuts = multi_dataset.dev_cuts()
valid_cuts = multi_dataset.aishell_dev_cuts()
valid_dl = data_module.valid_dataloaders(valid_cuts) valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0: if args.tensorboard and rank == 0:
@ -818,14 +855,20 @@ def run(rank, world_size, args):
model.save_checkpoint( model.save_checkpoint(
save_dir=params.exp_dir, save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}", tag=f"epoch-{params.cur_epoch}",
client_state={"sampler": train_dl.sampler.state_dict()}, client_state={},
exclude_frozen_parameters=True
) )
if rank == 0: if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict( convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir, params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}", tag=f"epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
) )
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt")
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")