fix down sample method

This commit is contained in:
root 2024-06-06 13:25:29 +00:00 committed by Yuekai Zhang
parent 796663066f
commit 412e926941
4 changed files with 91 additions and 15 deletions

View File

@ -4,7 +4,7 @@ export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
# pip install -r whisper/requirements.txt
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
--max-duration 1 \
--max-duration 20 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name tiny \
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \

View File

@ -1 +0,0 @@
../../../aishell/ASR/whisper/ds_config_zero1.json

View File

@ -0,0 +1,38 @@
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 100,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 0.01
},
"zero_optimization": {
"stage": 1,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 1e-4,
"warmup_num_steps": 100
}
},
"gradient_accumulation_steps": 1,
"gradient_clipping": 5,
"steps_per_print": 50,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false
}

View File

@ -5,14 +5,25 @@ from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class EncoderProjector(nn.Module):
def __init__(self, encoder_dim, llm_dim):
# https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
def __init__(self, encoder_dim, llm_dim, downsample_rate=4):
super().__init__()
self.linear1 = nn.Linear(encoder_dim, llm_dim)
self.downsample_rate = downsample_rate
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(llm_dim, llm_dim)
def forward(self, x):
batch_size, seq_len, feat_dim = x.size()
num_frames_to_discard = seq_len % self.downsample_rate
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
@ -124,7 +135,7 @@ class SPEECH_LLM(nn.Module):
):
encoder_outs = self.encoder(fbank)
# downsample encoder_outs by 4
encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
# encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
speech_features = self.encoder_projector(encoder_outs)
@ -138,8 +149,10 @@ class SPEECH_LLM(nn.Module):
#print("speech_features", speech_features.shape)
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
return model_outputs
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
return model_outputs, acc
def decode(self,
@ -151,7 +164,7 @@ class SPEECH_LLM(nn.Module):
encoder_outs = self.encoder(fbank)
# downsample encoder_outs by 4
encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
# encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
@ -178,3 +191,23 @@ class SPEECH_LLM(nn.Module):
# output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
# ]
return generated_ids
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type

View File

@ -420,7 +420,11 @@ def compute_loss(
# first get the indices of the tokens
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
target_ids[mask_indices[0], :mask_indices[1]+4] = IGNORE_TOKEN_ID
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
target_ids[row, :col+4] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
@ -496,13 +500,13 @@ def compute_loss(
input_ids = input_ids.type(torch.LongTensor)
with torch.set_grad_enabled(is_training):
model_outpus = model(
model_outputs, acc = model(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
)
loss = model_outpus.loss
loss = model_outputs.loss
assert loss.requires_grad == is_training
info = MetricsTracker()
@ -513,6 +517,7 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["acc"] = acc
return loss, info
@ -731,7 +736,8 @@ def run(rank, world_size, args):
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype=torch.bfloat16
# torch_dtype=torch.bfloat16
torch_dtype=torch.float16
else:
attn_implementation = "eager"