Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-27 02:34:21 +00:00
fix bugs

commit b5a906cbbd (parent e495c9d732)
@@ -1,15 +1,14 @@
-export PYTHONPATH=$PYTHONPATH:/workspace/asr/icefall
+export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
 # pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
 # pip install -r whisper/requirements.txt
+export CUDA_VISIBLE_DEVICES=0,1

-method=mask_predict
-# method=cif_ar_distill_embedding
-torchrun --nproc_per_node 8 ./parawhisper/train.py \
-  --max-duration 200 \
-  --exp-dir parawhisper/exp_large_v2_${method} \
-  --model-name large-v2 \
+torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
+  --max-duration 80 \
+  --exp-dir ./whisper_llm_zh/exp_test \
+  --speech-encoder-path-or-name tiny \
+  --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
   --manifest-dir data/fbank \
-  --method $method \
   --deepspeed \
-  --deepspeed_config ./whisper/ds_config_zero1.json
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --use-flash-attn False
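A minimal sketch of how the --deepspeed/--deepspeed_config flags above are consumed inside train.py, assuming the standard DeepSpeed argparse integration; the stand-in model and hard-coded argv are illustrative only, and the script must still be started under torchrun as shown above:

import argparse

import deepspeed
import torch.nn as nn

parser = argparse.ArgumentParser()
# deepspeed.add_config_arguments adds --deepspeed and --deepspeed_config,
# matching the flags passed in the launch script above.
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(["--deepspeed", "--deepspeed_config", "ds_config_zero1.json"])

model = nn.Linear(80, 10)  # stand-in for the real SPEECH_LLM
# deepspeed.initialize wraps the model in an engine that owns the optimizer,
# ZeRO partitioning, and checkpoint I/O (model.save_checkpoint further down).
model_engine, optimizer, _, scheduler = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters()
)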
@@ -1,7 +1,8 @@
 from torch import nn
 import torch
+from transformers.trainer_pt_utils import LabelSmoother

-DEFAULT_SPEECH_TOKEN = -1997  # "<speech>"
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index


 class EncoderProjector(nn.Module):
@@ -18,7 +19,6 @@ class EncoderProjector(nn.Module):
         return x


 class SPEECH_LLM(nn.Module):
-    # https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/slam_model.py
     def __init__(
         self,
         encoder: nn.Module,
@@ -28,8 +28,12 @@ class SPEECH_LLM(nn.Module):
         super().__init__()

         self.encoder = encoder
+        for name, param in encoder.named_parameters():
+            param.requires_grad = False
         self.encoder.eval()
         self.llm = llm
+        for name, param in llm.named_parameters():
+            param.requires_grad = False
         self.llm.eval()
         self.encoder_projector = encoder_projector
         self.encoder_outputs_downsample_rate = 4
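The constructor change above freezes both the speech encoder and the LLM so that only the projector trains. A self-contained sketch of the same freeze-then-eval pattern (the tiny Sequential is a stand-in):

import torch.nn as nn

encoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))  # stand-in encoder

# Disable gradients parameter by parameter, as the diff does.
for _, param in encoder.named_parameters():
    param.requires_grad = False
# eval() additionally fixes dropout/normalization behavior inside the
# frozen module while the rest of the network trains.
encoder.eval()

assert all(not p.requires_grad for p in encoder.parameters())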
@@ -39,11 +43,11 @@ class SPEECH_LLM(nn.Module):
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
         # 1. Create a mask to know where special speech tokens are
-        special_speech_token_mask = input_ids == DEFAULT_SPEECH_TOKEN
+        special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
         num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
         # Compute the maximum embed dimension
         max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length
-        batch_indices, non_speech_indices = torch.where(input_ids != DEFAULT_SPEECH_TOKEN)
+        batch_indices, non_speech_indices = torch.where(input_ids != self.llm.config.default_speech_token_id)

         # 2. Compute the positions where text should be written
         # Calculate new positions for text tokens in merged speech-text sequence.
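A toy run of the expansion arithmetic above: every <speech> placeholder is replaced by speech_len frames of projected encoder output, so the merged sequence grows by (speech_len - 1) per placeholder. The token id below is hypothetical:

import torch

speech_token_id = 151646  # hypothetical id; the real one comes from llm.config
input_ids = torch.tensor([[1, speech_token_id, 2, 3]])
speech_len = 5

special_mask = input_ids == speech_token_id
num_special = special_mask.sum(dim=-1)
max_embed_dim = (num_special.max() * (speech_len - 1)) + input_ids.shape[1]
assert max_embed_dim.item() == 8  # 4 text slots + (5 - 1) extra speech slots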
@@ -65,7 +69,7 @@ class SPEECH_LLM(nn.Module):
         )
         if labels is not None:
             final_labels = torch.full(
-                (batch_size, max_embed_dim), self., dtype=input_ids.dtype, device=input_ids.device
+                (batch_size, max_embed_dim), IGNORE_TOKEN_ID, dtype=input_ids.dtype, device=input_ids.device
             )
         # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
         # set the corresponding tensors into their correct target device.
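The fix replaces the truncated `self.` with IGNORE_TOKEN_ID, i.e. LabelSmoother.ignore_index (-100), which PyTorch's cross entropy skips by default, so speech and padding positions contribute no loss. A quick self-contained check:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 10)  # (batch, seq, vocab)
labels = torch.tensor([[2, -100, 5, -100]])  # -100 == default ignore_index
# Only the two real labels enter the average; the -100 slots are skipped.
loss = F.cross_entropy(logits.transpose(1, 2), labels)
print(loss)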
@@ -128,17 +132,17 @@ class SPEECH_LLM(nn.Module):
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )

-        outputs = self.language_model(
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        logits = outputs[0]
+        # outputs = self.llm(
+        #     attention_mask=attention_mask,
+        #     position_ids=position_ids,
+        #     past_key_values=past_key_values,
+        #     inputs_embeds=inputs_embeds,
+        #     use_cache=use_cache,
+        #     output_attentions=output_attentions,
+        #     output_hidden_states=output_hidden_states,
+        #     return_dict=return_dict,
+        # )
+        # logits = outputs[0]

         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
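The dead `self.language_model` block is commented out in favor of a single call that lets the Hugging Face causal LM compute the shifted cross-entropy itself. A standalone sketch of that calling convention (the checkpoint name is the one from the launch script; downloading it is assumed):

import torch
from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
ids = torch.tensor([[1, 2, 3]])
# inputs_embeds replaces input_ids, which is what lets SPEECH_LLM splice
# projected speech features into the embedding sequence beforehand.
embeds = llm.get_input_embeddings()(ids)
out = llm(inputs_embeds=embeds, labels=ids)
print(out.loss, out.logits.shape)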
@@ -255,4 +255,16 @@ class MultiDataset:
                 self.fbank_dir / "aishell_cuts_train.jsonl.gz"
             )

             return aishell_cuts
+
+    def aishell_dev_cuts(self) -> CutSet:
+        logging.info("About to get multidataset dev cuts")
+
+        # AISHELL
+        logging.info("Loading Aishell set in lazy mode")
+        aishell_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
+        )
+
+        return aishell_dev_cuts
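The new aishell_dev_cuts mirrors the train variant; lhotse's load_manifest_lazy streams cuts from the gzipped JSONL instead of materializing them, which is what "lazy mode" means in the log message. Usage sketch, assuming the manifest exists at the path used above:

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/aishell_cuts_dev.jsonl.gz")
for cut in cuts:  # iteration reads from disk on demand
    print(cut.id)
    break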
@@ -39,13 +39,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import deepspeed
 import k2
-import optim
+# import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import whisper
 from asr_datamodule import AsrDataModule
-from model import SPEECH_LLM, EncoderProjector
+from model import SPEECH_LLM, EncoderProjector, IGNORE_TOKEN_ID
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
@@ -59,7 +59,7 @@ from torch.cuda.amp import GradScaler
 from torch.nn.functional import pad as pad_tensor
 # from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

 from icefall import diagnostics
@@ -78,13 +78,12 @@ from icefall.utils import (
 )

 from transformers import AutoModelForCausalLM, AutoTokenizer
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+import transformers
+from transformers.trainer_pt_utils import LabelSmoother

+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+DEFAULT_SPEECH_TOKEN = "<speech>"

-def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
-    if isinstance(model, DDP):
-        # get underlying nn.Module
-        model = model.module
+def set_batch_count(model: nn.Module, batch_count: float) -> None:
     for module in model.modules():
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count
@@ -240,6 +239,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)

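str2bool here is icefall's usual argparse helper (it lives in icefall.utils), which is what makes `--use-flash-attn False` in the launch script parse as a real boolean. A stand-in reimplementation to show the behavior:

import argparse

def str2bool(v: str) -> bool:  # stand-in; behavior assumed from usage
    if v.lower() in ("true", "1", "yes", "y"):
        return True
    if v.lower() in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--use-flash-attn", type=str2bool, default=True)
assert parser.parse_args(["--use-flash-attn", "False"]).use_flash_attn is False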
@@ -272,6 +278,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "allowed_excess_duration_ratio": 0.1,
+            "subsampling_factor": 2,
             "frame_shift_ms": 10,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -357,7 +364,7 @@ def get_params() -> AttributeDict:
 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     batch: dict,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
@@ -397,7 +404,6 @@ def compute_loss(
             texts.append(
                 tokenizer.apply_chat_template(
                     msg,
-                    chat_template=TEMPLATE,
                     tokenize=True,
                     add_generation_prompt=False,
                     padding="max_length",
@@ -405,8 +411,9 @@ def compute_loss(
                     truncation=True,
                 )
             )
-    # model_inputs = tokenizer([text], return_tensors="pt").to(device)
     input_ids = torch.tensor(texts, dtype=torch.int)
+    # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
     target_ids = input_ids.clone()
     target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
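A toy version of the target/mask construction above: pad positions become IGNORE_TOKEN_ID in the labels (so the loss skips them) and 0 in the attention mask (so they are not attended to):

import torch

IGNORE_TOKEN_ID = -100  # LabelSmoother.ignore_index
pad_token_id = 0  # example value; the real id comes from the tokenizer

input_ids = torch.tensor([[5, 6, 7, pad_token_id]])
target_ids = input_ids.clone()
target_ids[target_ids == pad_token_id] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(pad_token_id)
print(target_ids)      # tensor([[   5,    6,    7, -100]])
print(attention_mask)  # tensor([[ True,  True,  True, False]])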
@@ -453,47 +460,49 @@ def compute_loss(
     allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
     batch = filter_uneven_sized_batch(batch, allowed_max_frames)

-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = next(model.parameters()).device
     feature = batch["inputs"]

     assert feature.ndim == 3
     feature = feature.to(device)
     feature = feature.transpose(1, 2)  # (N, C, T)

-    # feature_lens = supervisions["num_frames"].to(device)

     batch_idx_train = params.batch_idx_train
     supervisions = batch["supervisions"]
     texts = batch["supervisions"]["text"]
     # remove spaces in texts
     texts = [normalize_text_alimeeting(text) for text in texts]

-    # messages = [
-    #     {"role": "system", "content": "You are a helpful assistant."},
-    #     {"role": "user", "content": prompt}
-    # ]
+    messages = []
+    for i, text in enumerate(texts):
+        message = [
+            {"role": "system", "content": "你是一个能处理音频的助手。"},
+            {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
+            {"role": "assistant", "content": text},
+        ]
+        messages.append(message)

     input_ids, attention_mask, target_ids = preprocess(
-        texts, tokenizer, max_len=512
+        messages, tokenizer, max_len=128
     )

-    # decoder_criterion = LabelSmoothingLoss(
-    #     ignore_index=50256, label_smoothing=0.1, reduction="sum"
-    # )
+    target_ids = target_ids.type(torch.LongTensor)
+    input_ids = input_ids.type(torch.LongTensor)

-    # # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|>
-    # ignore_prefix_size = 3
-    # with torch.set_grad_enabled(is_training):
-    #     encoder_out = model.encoder(feature)
-    #     text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
-    #     text_logits = text_logits[:, ignore_prefix_size:, :]
-    #     target_tokens = target_tokens[:, ignore_prefix_size:]
-    #     loss = decoder_criterion(text_logits, target_tokens.to(device))
-    # assert loss.requires_grad == is_training
+    with torch.set_grad_enabled(is_training):
+        model_outpus = model(
+            fbank=feature,
+            input_ids=input_ids.to(device),
+            attention_mask=attention_mask.to(device),
+            labels=target_ids.to(device),
+        )
+        loss = model_outpus.loss
+    assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
+        feature_lens = supervisions["num_frames"]
         info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
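The rewritten compute_loss builds one chat per utterance: a system prompt ("You are an assistant that can process audio."), a user turn ("Please transcribe the audio to text") followed by the <speech> placeholder, and an assistant turn holding the reference transcript. A sketch of what preprocess is assumed to do with one such message list (the checkpoint name comes from the launch script; comments gloss the Chinese):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
message = [
    {"role": "system", "content": "你是一个能处理音频的助手。"},  # "You are an assistant that can process audio."
    {"role": "user", "content": "请转写音频为文字 <speech>"},  # "Please transcribe the audio to text"
    {"role": "assistant", "content": "今天天气不错"},  # example transcript
]
ids = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=False)
print(tokenizer.decode(ids))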
@@ -505,7 +514,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     tokenizer: whisper.tokenizer.Tokenizer,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -540,9 +549,9 @@ def compute_validation_loss(
 def train_one_epoch(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
-    model: Union[nn.Module, DDP],
+    model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    scheduler: LRSchedulerType,
+    scheduler: torch.optim.lr_scheduler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
@@ -609,18 +618,26 @@ def train_one_epoch(
                 model.save_checkpoint(
                     save_dir=params.exp_dir,
                     tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                    client_state={"sampler": train_dl.sampler.state_dict()},
+                    client_state={},
+                    exclude_frozen_parameters=True
                 )

                 if rank == 0:
                     convert_zero_checkpoint_to_fp32_state_dict(
                         params.exp_dir,
                         f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
                         tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                        exclude_frozen_parameters=True,
+                    )
+                    # save sampler state dict into checkpoint
+                    sampler_state_dict = train_dl.sampler.state_dict()
+                    torch.save(
+                        sampler_state_dict,
+                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
                     )
                     os.system(
                         f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
                     )

         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
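Saving the sampler state in a plain .pt file (instead of DeepSpeed's client_state) keeps it readable without reassembling ZeRO shards, and exclude_frozen_parameters keeps the frozen encoder/LLM weights out of the checkpoint. A hypothetical resume-side counterpart, with example epoch/batch numbers:

import torch

exp_dir = "whisper_llm_zh/exp_test"
sampler_state_dict = torch.load(f"{exp_dir}/epoch-1-checkpoint-8000-sampler.pt")
# then: data_module.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict)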
@@ -698,19 +715,32 @@ def run(rank, world_size, args):

     logging.info("About to create model")

-    if 'whisper' in params.speech_encoder_path_or_name:
+    # if 'whisper' in params.speech_encoder_path_or_name:
     replace_whisper_encoder_forward()
     # TODO: directly loading from whisper-ft checkpoint
     # whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-    speech_encoder = whisper.load_model(params.model_name, "cpu").encoder
-    speech_encoder_dim = speech_encoder.dims.n_audio_ctx
+    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
+    speech_encoder = whisper_model.encoder
+    speech_encoder_dim = whisper_model.dims.n_audio_state
+
+    if params.use_flash_attn:
+        attn_implementation = "flash_attention_2"
+    else:
+        attn_implementation = "eager"

     llm = AutoModelForCausalLM.from_pretrained(
         params.llm_path_or_name,
-        attn_implemented="flash_attention_2",
-        device_map="cpu"
+        attn_implementation=attn_implementation,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+    tokenizer.padding_side = 'left'
+    special_tokens_dict = {
+        "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
+    }
+    tokenizer.add_special_tokens(special_tokens_dict)
+    llm.config.pad_token_id = tokenizer.pad_token_id
+    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)

     encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)

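A standalone illustration of the new tokenizer wiring: <speech> is registered as an additional special token, and its fresh id is stored on the LLM config so SPEECH_LLM can locate the placeholder later. Note the diff never calls resize_token_embeddings; that is safe only when the checkpoint's embedding matrix already has spare rows past the tokenizer vocabulary:

from transformers import AutoTokenizer

DEFAULT_SPEECH_TOKEN = "<speech>"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
tokenizer.padding_side = "left"  # pads go before the prompt, as in the diff
tokenizer.add_special_tokens({"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]})
speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
print(speech_token_id, tokenizer.decode([speech_token_id]))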
@@ -723,6 +753,11 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    logging.info("Trainable parameters (excluding model.eval modules):")
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            logging.info(f"{name}: {param.shape}")
+
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
     else:
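The added loop only logs tensors with requires_grad set, which after the constructor changes should cover just the encoder_projector. A stand-in check of the same idea:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))  # stand-in model
for p in model[0].parameters():
    p.requires_grad = False  # mimic the frozen encoder/LLM

num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
assert num_trainable == 10  # only the second Linear: 4*2 weights + 2 biases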
@@ -770,12 +805,14 @@ def run(rank, world_size, args):
     #     sampler_state_dict = checkpoints["sampler"]
     # else:
     #     sampler_state_dict = None
+    sampler_state_dict = None
     # TODO: load sampler state dict
     train_dl = data_module.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = multi_dataset.dev_cuts()
+    # valid_cuts = multi_dataset.dev_cuts()
+    valid_cuts = multi_dataset.aishell_dev_cuts()
     valid_dl = data_module.valid_dataloaders(valid_cuts)

     if args.tensorboard and rank == 0:
@@ -818,14 +855,20 @@ def run(rank, world_size, args):
         model.save_checkpoint(
             save_dir=params.exp_dir,
             tag=f"epoch-{params.cur_epoch}",
-            client_state={"sampler": train_dl.sampler.state_dict()},
+            client_state={},
+            exclude_frozen_parameters=True
         )
         if rank == 0:
             convert_zero_checkpoint_to_fp32_state_dict(
                 params.exp_dir,
                 f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
                 tag=f"epoch-{params.cur_epoch}",
+                exclude_frozen_parameters=True,
             )
+            # save sampler state dict into checkpoint
+            sampler_state_dict = train_dl.sampler.state_dict()
+            torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt")
+
             os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")