mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
fix bugs
This commit is contained in:
parent
e495c9d732
commit
b5a906cbbd
@ -1,15 +1,14 @@
|
||||
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/asr/icefall
|
||||
export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
|
||||
# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
# pip install -r whisper/requirements.txt
|
||||
|
||||
method=mask_predict
|
||||
# method=cif_ar_distill_embedding
|
||||
torchrun --nproc_per_node 8 ./parawhisper/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir parawhisper/exp_large_v2_${method} \
|
||||
--model-name large-v2 \
|
||||
export CUDA_VISIBLE_DEVICES=0,1
|
||||
torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
|
||||
--max-duration 80 \
|
||||
--exp-dir ./whisper_llm_zh/exp_test \
|
||||
--speech-encoder-path-or-name tiny \
|
||||
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
|
||||
--manifest-dir data/fbank \
|
||||
--method $method \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./whisper/ds_config_zero1.json
|
||||
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
|
||||
--use-flash-attn False
|
@ -1,7 +1,8 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from transformers.trainer_pt_utils import LabelSmoother
|
||||
|
||||
DEFAULT_SPEECH_TOKEN = -1997 # "<speech>"
|
||||
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
||||
|
||||
class EncoderProjector(nn.Module):
|
||||
|
||||
@ -18,7 +19,6 @@ class EncoderProjector(nn.Module):
|
||||
return x
|
||||
|
||||
class SPEECH_LLM(nn.Module):
|
||||
# https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/slam_model.py
|
||||
def __init__(
|
||||
self,
|
||||
encoder: nn.Module,
|
||||
@ -28,8 +28,12 @@ class SPEECH_LLM(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.encoder = encoder
|
||||
for name, param in encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.encoder.eval()
|
||||
self.llm = llm
|
||||
for name, param in llm.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.llm.eval()
|
||||
self.encoder_projector = encoder_projector
|
||||
self.encoder_outputs_downsample_rate = 4
|
||||
@ -39,11 +43,11 @@ class SPEECH_LLM(nn.Module):
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
|
||||
# 1. Create a mask to know where special speech tokens are
|
||||
special_speech_token_mask = input_ids == DEFAULT_SPEECH_TOKEN
|
||||
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
|
||||
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length
|
||||
batch_indices, non_speech_indices = torch.where(input_ids != DEFAULT_SPEECH_TOKEN)
|
||||
batch_indices, non_speech_indices = torch.where(input_ids != self.llm.config.default_speech_token_id)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged speech-text sequence.
|
||||
@ -65,7 +69,7 @@ class SPEECH_LLM(nn.Module):
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim), self., dtype=input_ids.dtype, device=input_ids.device
|
||||
(batch_size, max_embed_dim), IGNORE_TOKEN_ID, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
@ -128,17 +132,17 @@ class SPEECH_LLM(nn.Module):
|
||||
speech_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
logits = outputs[0]
|
||||
# outputs = self.llm(
|
||||
# attention_mask=attention_mask,
|
||||
# position_ids=position_ids,
|
||||
# past_key_values=past_key_values,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# use_cache=use_cache,
|
||||
# output_attentions=output_attentions,
|
||||
# output_hidden_states=output_hidden_states,
|
||||
# return_dict=return_dict,
|
||||
# )
|
||||
# logits = outputs[0]
|
||||
|
||||
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
|
||||
|
||||
|
@ -256,3 +256,15 @@ class MultiDataset:
|
||||
)
|
||||
|
||||
return aishell_cuts
|
||||
|
||||
|
||||
def aishell_dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset dev cuts")
|
||||
|
||||
# AISHELL
|
||||
logging.info("Loading Aishell set in lazy mode")
|
||||
aishell_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
return aishell_dev_cuts
|
@ -39,13 +39,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import deepspeed
|
||||
import k2
|
||||
import optim
|
||||
# import optim
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import whisper
|
||||
from asr_datamodule import AsrDataModule
|
||||
from model import SPEECH_LLM, EncoderProjector
|
||||
from model import SPEECH_LLM, EncoderProjector, IGNORE_TOKEN_ID
|
||||
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from lhotse import CutSet, load_manifest
|
||||
@ -59,7 +59,7 @@ from torch.cuda.amp import GradScaler
|
||||
from torch.nn.functional import pad as pad_tensor
|
||||
# from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
|
||||
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
|
||||
from icefall import diagnostics
|
||||
@ -78,13 +78,12 @@ from icefall.utils import (
|
||||
)
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
import transformers
|
||||
from transformers.trainer_pt_utils import LabelSmoother
|
||||
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
||||
DEFAULT_SPEECH_TOKEN = "<speech>"
|
||||
|
||||
|
||||
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
|
||||
if isinstance(model, DDP):
|
||||
# get underlying nn.Module
|
||||
model = model.module
|
||||
def set_batch_count(model: nn.Module, batch_count: float) -> None:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "batch_count"):
|
||||
module.batch_count = batch_count
|
||||
@ -240,6 +239,13 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-flash-attn",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to use flash attention.",
|
||||
)
|
||||
|
||||
parser = deepspeed.add_config_arguments(parser)
|
||||
add_model_arguments(parser)
|
||||
|
||||
@ -272,6 +278,7 @@ def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"allowed_excess_duration_ratio": 0.1,
|
||||
"subsampling_factor": 2,
|
||||
"frame_shift_ms": 10,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
@ -357,7 +364,7 @@ def get_params() -> AttributeDict:
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
tokenizer: AutoTokenizer,
|
||||
model: Union[nn.Module, DDP],
|
||||
model: nn.Module,
|
||||
batch: dict,
|
||||
is_training: bool,
|
||||
) -> Tuple[Tensor, MetricsTracker]:
|
||||
@ -397,7 +404,6 @@ def compute_loss(
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
chat_template=TEMPLATE,
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
padding="max_length",
|
||||
@ -405,8 +411,9 @@ def compute_loss(
|
||||
truncation=True,
|
||||
)
|
||||
)
|
||||
# model_inputs = tokenizer([text], return_tensors="pt").to(device)
|
||||
|
||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
|
||||
target_ids = input_ids.clone()
|
||||
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
@ -453,47 +460,49 @@ def compute_loss(
|
||||
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
||||
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
||||
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
feature = feature.transpose(1, 2) # (N, C, T)
|
||||
|
||||
# feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
batch_idx_train = params.batch_idx_train
|
||||
supervisions = batch["supervisions"]
|
||||
texts = batch["supervisions"]["text"]
|
||||
# remove spaces in texts
|
||||
texts = [normalize_text_alimeeting(text) for text in texts]
|
||||
|
||||
# messages = [
|
||||
# {"role": "system", "content": "You are a helpful assistant."},
|
||||
# {"role": "user", "content": prompt}
|
||||
# ]
|
||||
messages = []
|
||||
for i, text in enumerate(texts):
|
||||
message = [
|
||||
{"role": "system", "content": "你是一个能处理音频的助手。"},
|
||||
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": text},
|
||||
]
|
||||
messages.append(message)
|
||||
|
||||
input_ids, attention_mask, target_ids = preprocess(
|
||||
texts, tokenizer, max_len=512
|
||||
messages, tokenizer, max_len=128
|
||||
)
|
||||
|
||||
# decoder_criterion = LabelSmoothingLoss(
|
||||
# ignore_index=50256, label_smoothing=0.1, reduction="sum"
|
||||
# )
|
||||
target_ids = target_ids.type(torch.LongTensor)
|
||||
input_ids = input_ids.type(torch.LongTensor)
|
||||
|
||||
# # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|>
|
||||
# ignore_prefix_size = 3
|
||||
# with torch.set_grad_enabled(is_training):
|
||||
# encoder_out = model.encoder(feature)
|
||||
# text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
|
||||
# text_logits = text_logits[:, ignore_prefix_size:, :]
|
||||
# target_tokens = target_tokens[:, ignore_prefix_size:]
|
||||
# loss = decoder_criterion(text_logits, target_tokens.to(device))
|
||||
|
||||
# assert loss.requires_grad == is_training
|
||||
with torch.set_grad_enabled(is_training):
|
||||
model_outpus = model(
|
||||
fbank=feature,
|
||||
input_ids=input_ids.to(device),
|
||||
attention_mask=attention_mask.to(device),
|
||||
labels=target_ids.to(device),
|
||||
)
|
||||
loss = model_outpus.loss
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
feature_lens = supervisions["num_frames"]
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
@ -505,7 +514,7 @@ def compute_loss(
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
tokenizer: whisper.tokenizer.Tokenizer,
|
||||
model: Union[nn.Module, DDP],
|
||||
model: nn.Module,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> MetricsTracker:
|
||||
@ -540,9 +549,9 @@ def compute_validation_loss(
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
tokenizer: AutoTokenizer,
|
||||
model: Union[nn.Module, DDP],
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: LRSchedulerType,
|
||||
scheduler: torch.optim.lr_scheduler,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
@ -609,18 +618,26 @@ def train_one_epoch(
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
client_state={"sampler": train_dl.sampler.state_dict()},
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
sampler_state_dict = train_dl.sampler.state_dict()
|
||||
torch.save(
|
||||
sampler_state_dict,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
|
||||
)
|
||||
os.system(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
)
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
@ -698,19 +715,32 @@ def run(rank, world_size, args):
|
||||
|
||||
logging.info("About to create model")
|
||||
|
||||
if 'whisper' in params.speech_encoder_path_or_name:
|
||||
replace_whisper_encoder_forward()
|
||||
# TODO: directly loading from whisper-ft checkpoint
|
||||
# whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
|
||||
speech_encoder = whisper.load_model(params.model_name, "cpu").encoder
|
||||
speech_encoder_dim = speech_encoder.dims.n_audio_ctx
|
||||
# if 'whisper' in params.speech_encoder_path_or_name:
|
||||
replace_whisper_encoder_forward()
|
||||
# TODO: directly loading from whisper-ft checkpoint
|
||||
# whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
|
||||
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
|
||||
speech_encoder = whisper_model.encoder
|
||||
speech_encoder_dim = whisper_model.dims.n_audio_state
|
||||
|
||||
if params.use_flash_attn:
|
||||
attn_implementation = "flash_attention_2"
|
||||
|
||||
else:
|
||||
attn_implementation = "eager"
|
||||
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
params.llm_path_or_name,
|
||||
attn_implemented="flash_attention_2",
|
||||
device_map="cpu"
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
|
||||
tokenizer.padding_side = 'left'
|
||||
special_tokens_dict = {
|
||||
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
|
||||
}
|
||||
tokenizer.add_special_tokens(special_tokens_dict)
|
||||
llm.config.pad_token_id = tokenizer.pad_token_id
|
||||
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
|
||||
|
||||
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
|
||||
|
||||
@ -723,6 +753,11 @@ def run(rank, world_size, args):
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
logging.info("Trainable parameters (excluding model.eval modules):")
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
logging.info(f"{name}: {param.shape}")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
else:
|
||||
@ -770,12 +805,14 @@ def run(rank, world_size, args):
|
||||
# sampler_state_dict = checkpoints["sampler"]
|
||||
# else:
|
||||
# sampler_state_dict = None
|
||||
sampler_state_dict = None
|
||||
# TODO: load sampler state dict
|
||||
train_dl = data_module.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
valid_cuts = multi_dataset.dev_cuts()
|
||||
# valid_cuts = multi_dataset.dev_cuts()
|
||||
valid_cuts = multi_dataset.aishell_dev_cuts()
|
||||
valid_dl = data_module.valid_dataloaders(valid_cuts)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
@ -818,14 +855,20 @@ def run(rank, world_size, args):
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"epoch-{params.cur_epoch}",
|
||||
client_state={"sampler": train_dl.sampler.state_dict()},
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True
|
||||
)
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
sampler_state_dict = train_dl.sampler.state_dict()
|
||||
torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt")
|
||||
|
||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user