Merge 70f13e54d814761432acc1c23e9ef4ffd566df41 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

Yifan Yang 2025-07-19 14:22:04 +08:00 committed by GitHub
commit c103dbef78
21 changed files with 2089 additions and 143 deletions

View File

@ -3,7 +3,7 @@
This recipe includes scripts for training Zipformer model using multiple Chinese datasets.
# Included Training Sets
# Included Training Dataset
1. THCHS-30
2. AiShell-{1,2,4}
3. ST-CMDS
@ -14,7 +14,7 @@ This recipe includes scripts for training Zipformer model using multiple Chinese
8. WeNetSpeech
9. KeSpeech-ASR
|Datset| Number of hours| URL|
|Dataset| Number of hours| URL|
|---|---:|---|
|**TOTAL**|14,106|---|
|THCHS-30|35|https://www.openslr.org/18/|

View File

@ -99,7 +99,7 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the
| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Split | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Transducer Greedy Offline | 21.67 | 23.43 | 1.22 | 1.31 | 3.17 | 3.27 | 14.64 | 2.42 | 1.99 | 5.00 | 2.29 | 5.98 | 5.15 | 5.85 | 6.89 |
Pre-trained model can be found here : https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-xl
@ -152,7 +152,7 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the
| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Split | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| CTC Greedy Streaming | 26.50 | 28.10| 1.71 | 1.97| 3.89| 4.06 | 17.23 | 3.69 | 2.87 | 8.14 | 3.61 |9.51 | 6.11 | 8.13 | 10.62 |
| CTC Greedy Offline | 23.47 | 25.02 | 1.39 | 1.50 | 3.15 | 3.41 | 15.14 | 3.07 | 2.37 | 6.06 | 2.90 | 7.13 | 5.40 | 6.52 | 9.64 |
| Transducer Greedy Offline | 23.16 | 24.78 | 1.33 | 1.38 | 3.06 | 3.23 | 15.36 | 2.54 | 2.09 | 5.24 | 2.28 | 6.26 | 4.87 | 6.26 | 7.07 |
@ -193,7 +193,7 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the
| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Split | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 |
| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 |
@ -226,7 +226,7 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the
| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Split | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 |

View File

@ -109,6 +109,25 @@ class AsrDataModule:
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--num-cuts-for-bins-estimate",
type=int,
default=10000,
help="We will draw this many cuts to estimate the duration "
"bins for creating similar-duration buckets. A larger number "
"means a better estimate of the data distribution, possibly "
"at a longer init cost.",
)
group.add_argument(
"--quadratic-duration",
type=float,
default=None,
help="When set, it adds an extra penalty that's quadratic "
"in size w.r.t. a cut's duration. This helps get a more "
"even GPU utilization across different input lengths when "
"models have quadratic input complexity. Set between 15 "
"and 40 for transformers.",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
@ -205,6 +224,8 @@ class AsrDataModule:
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@ -295,11 +316,15 @@ class AsrDataModule:
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
quadratic_duration=self.args.quadratic_duration,
num_cuts_for_bins_estimate=self.args.num_cuts_for_bins_estimate,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@ -307,6 +332,8 @@ class AsrDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@ -330,7 +357,12 @@ class AsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
def valid_dataloaders(
self,
cuts_valid: CutSet,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@ -355,6 +387,8 @@ class AsrDataModule:
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(

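The --quadratic-duration and --num-cuts-for-bins-estimate options added above are passed straight through to lhotse's DynamicBucketingSampler, and the dataloader helpers now also accept explicit world_size/rank for DDP. A minimal sketch of the quadratic-duration effect, assuming lhotse's documented behavior (a cut of duration d counts as d + d*d/Q seconds toward --max-duration):

def effective_duration(d: float, quadratic_duration: float) -> float:
    # assumed lhotse behavior: a quadratic penalty is added to the measured duration
    return d + (d * d) / quadratic_duration

for d in (5.0, 15.0, 30.0):
    print(d, "->", effective_duration(d, quadratic_duration=25.0))
# 5.0 -> 6.0, 15.0 -> 24.0, 30.0 -> 66.0: long cuts fill a batch faster,
# which evens out GPU memory use for models with quadratic input complexity.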
egs/speech_llm/ASR_LLM/.gitignore (new file)
View File

@ -0,0 +1,4 @@
models
train*.sh
decode*.sh
sync*.sh

View File

@ -37,6 +37,15 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
mkdir data/fbank/wenetspeech
mv data/fbank/cuts_L_fixed.jsonl.gz data/fbank/wenetspeech/
mv data/fbank/cuts_DEV_fixed.jsonl.gz data/fbank/wenetspeech/
mv data/fbank/cuts_TEST_MEETING.jsonl.gz data/fbank/wenetspeech/
mv data/fbank/cuts_TEST_NET.jsonl.gz data/fbank/wenetspeech/
mv data/fbank/L_split_100 data/fbank/wenetspeech/
mv data/fbank/feats_DEV.lca data/fbank/wenetspeech/
mv data/fbank/feats_TEST_MEETING.lca data/fbank/wenetspeech/
mv data/fbank/feats_TEST_NET.lca data/fbank/wenetspeech/
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@ -46,4 +55,5 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
mkdir data_speechio
huggingface-cli download --repo-type model --local-dir data_speechio yuekai/icefall_asr_speechio
mv data_speechio/fbank/* data/fbank
rm -rf data_speechio
fi

View File

@ -5,7 +5,7 @@
"loss_scale_window": 100,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 0.01
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 1,

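For context on the min_loss_scale change above, a rough sketch assuming DeepSpeed's usual dynamic loss-scaling behavior (start at 2**initial_scale_power, halve the scale on overflow, never drop below min_loss_scale); raising the floor from 0.01 to 1 keeps fp16 gradients from being scaled close to underflow:

initial_scale_power = 16
min_loss_scale = 1.0                        # raised from 0.01 in this commit
scale = 2.0 ** initial_scale_power          # 65536.0
for _ in range(20):                         # pretend 20 consecutive overflows
    scale = max(scale / 2.0, min_loss_scale)
print(scale)                                # 1.0 (with a 0.01 floor it could fall to 0.0625)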
View File

@ -194,10 +194,10 @@ class SPEECH_LLM(nn.Module):
def forward(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
fbank: torch.Tensor,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
labels: torch.LongTensor,
):
encoder_outs = self.encoder(fbank)

View File

@ -15,13 +15,10 @@
# limitations under the License.
import glob
import logging
import re
from pathlib import Path
from typing import Dict, List
from typing import Dict
import lhotse
from lhotse import CutSet, load_manifest_lazy
@ -50,103 +47,13 @@ class MultiDataset:
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
# THCHS-30
logging.info("Loading THCHS-30 in lazy mode")
thchs_30_cuts = load_manifest_lazy(
self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
)
# AISHELL-1
logging.info("Loading Aishell-1 in lazy mode")
aishell_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_train.jsonl.gz"
)
# AISHELL-2
logging.info("Loading Aishell-2 in lazy mode")
aishell_2_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
)
# AISHELL-4
logging.info("Loading Aishell-4 in lazy mode")
aishell_4_L_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
)
aishell_4_M_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
)
aishell_4_S_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
)
# ST-CMDS
logging.info("Loading ST-CMDS in lazy mode")
stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")
# Primewords
logging.info("Loading Primewords in lazy mode")
primewords_cuts = load_manifest_lazy(
self.fbank_dir / "primewords_cuts_train.jsonl.gz"
)
# MagicData
logging.info("Loading MagicData in lazy mode")
magicdata_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
)
# Ali-Meeting
logging.info("Loading Ali-Meeting in lazy mode")
alimeeting_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
)
# WeNetSpeech
logging.info("Loading WeNetSpeech in lazy mode")
wenetspeech_L_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
)
# KeSpeech
logging.info("Loading KeSpeech in lazy mode")
kespeech_1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
)
kespeech_2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
)
return CutSet.mux(
thchs_30_cuts,
aishell_cuts,
aishell_2_cuts,
aishell_4_L_cuts,
aishell_4_M_cuts,
aishell_4_S_cuts,
alimeeting_cuts,
stcmds_cuts,
primewords_cuts,
magicdata_cuts,
wenetspeech_L_cuts,
kespeech_1_cuts,
kespeech_2_cuts,
weights=[
len(thchs_30_cuts),
len(aishell_cuts),
len(aishell_2_cuts),
len(aishell_4_L_cuts),
len(aishell_4_M_cuts),
len(aishell_4_S_cuts),
len(alimeeting_cuts),
len(stcmds_cuts),
len(primewords_cuts),
len(magicdata_cuts),
len(wenetspeech_L_cuts),
len(kespeech_1_cuts),
len(kespeech_2_cuts),
],
)
return wenetspeech_L_cuts
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
@ -247,8 +154,7 @@ class MultiDataset:
}
def aishell_train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
logging.info("Loading Aishell-1 in lazy mode")
logging.info("Loading Aishell-1 train set in lazy mode")
aishell_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_train.jsonl.gz"
)
@ -256,8 +162,7 @@ class MultiDataset:
return aishell_cuts
def aishell_dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
logging.info("Loading Aishell set in lazy mode")
logging.info("Loading Aishell-1 dev set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)
@ -265,8 +170,7 @@ class MultiDataset:
return aishell_dev_cuts
def aishell_test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
logging.info("Loading Aishell set in lazy mode")
logging.info("Loading Aishell-1 test set in lazy mode")
aishell_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_test.jsonl.gz"
)
@ -276,8 +180,7 @@ class MultiDataset:
}
def aishell2_train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
logging.info("Loading Aishell-2 in lazy mode")
logging.info("Loading Aishell-2 train set in lazy mode")
aishell_2_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
)
@ -285,8 +188,7 @@ class MultiDataset:
return aishell_2_cuts
def aishell2_dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
logging.info("Loading Aishell-2 set in lazy mode")
logging.info("Loading Aishell-2 dev set in lazy mode")
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)
@ -294,8 +196,7 @@ class MultiDataset:
return aishell2_dev_cuts
def aishell2_test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
logging.info("Loading Aishell-2 set in lazy mode")
logging.info("Loading Aishell-2 test set in lazy mode")
aishell2_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
)
@ -304,9 +205,28 @@ class MultiDataset:
"aishell2_test": aishell2_test_cuts,
}
def wenetspeech_dev_cuts(self) -> CutSet:
logging.info("Loading WeNetSpeech DEV set in lazy mode")
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)
return {
"wenetspeech-dev": wenetspeech_dev_cuts,
}
def wenetspeech_test_net_cuts(self) -> CutSet:
logging.info("Loading WeNetSpeech-net test set in lazy mode")
wenetspeech_test_net_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
)
return {
"wenetspeech-net_test": wenetspeech_test_net_cuts,
}
def wenetspeech_test_meeting_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
logging.info("Loading WeNetSpeech set in lazy mode")
logging.info("Loading WeNetSpeech-meeting test set in lazy mode")
wenetspeech_test_meeting_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
)
@ -316,7 +236,7 @@ class MultiDataset:
}
def speechio_test_cuts(self) -> Dict[str, CutSet]:
logging.info("About to get multidataset test cuts")
logging.info("Loading SpeechIO test set in lazy mode")
start_index = 0
end_index = 26
dataset_parts = []

View File

@ -1,11 +1,6 @@
k2
kaldialign
git+https://github.com/lhotse-speech/lhotse
sentencepiece
pypinyin
tensorboard
librosa
deepspeed
transformers>=4.37.0
flash-attn
peft
openai-whisper

View File

@ -18,18 +18,6 @@
# limitations under the License.
"""
Usage:
# fine-tuning with whisper and Qwen2
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
@ -39,7 +27,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True
--use-lora True \
--unfreeze-llm True
"""
import argparse
@ -333,7 +322,6 @@ def compute_loss(
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"]
@ -378,7 +366,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
tokenizer: whisper.tokenizer.Tokenizer,
tokenizer: AutoTokenizer,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,

View File

@ -0,0 +1 @@
../whisper_llm_zh/asr_datamodule.py

View File

@ -0,0 +1,541 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
# Fangjun Kuang,
# Wei Kang)
# 2024 Yuekai Zhang
# 2025 Yifan Yang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
python3 ./zipformer_llm_zh/decode.py \
--max-duration 80 \
--exp-dir zipformer_llm_zh/exp \
--speech-encoder-path-or-name models/zipformer/epoch-999.pt \
--llm-path-or-name models/qwen \
--epoch 999 \
--avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True \
--dataset aishell
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
import transformers
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from train import (
DEFAULT_SPEECH_TOKEN,
_to_int_tuple,
add_model_arguments,
get_encoder_embed,
get_encoder_model,
get_params,
load_model_params,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from zipformer import Zipformer2
from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
"""Average a list of checkpoints.
The function is mainly used for averaging deepspeed-converted checkpoints, which only include the model state_dict.
Args:
filenames:
Filenames of the checkpoints to be averaged. We assume all
checkpoints are saved by :func:`save_checkpoint`.
device:
Move checkpoints to this device before averaging.
Returns:
Return a dict (i.e., state_dict) which is the average of all
model state dicts contained in the checkpoints.
"""
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
avg = torch.load(filenames[0], map_location=device)["model"]
else:
avg = torch.load(filenames[0], map_location=device)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
state_dict = torch.load(filenames[i], map_location=device)["model"]
else:
state_dict = torch.load(filenames[i], map_location=device)
for k in uniqued_names:
avg[k] += state_dict[k]
for k in uniqued_names:
if avg[k].is_floating_point():
avg[k] /= n
else:
avg[k] //= n
return avg
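# Usage sketch (hypothetical epoch numbers), mirroring main() below:
#   filenames = [f"{params.exp_dir}/epoch-{e}/pytorch_model.bin" for e in (8, 9, 10)]
#   model.load_state_dict(average_checkpoints(filenames), strict=False)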
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=-1,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="beam-search",
help="""Decoding method.
Supported values are:
- beam-search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=1,
help="beam size for beam search decoding",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--dataset",
type=str,
default="aishell",
choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
help="The dataset to decode",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
batch: dict,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: "beam-search"
- value: A list of lists. Each sublist is a list of token IDs.
Args:
params:
It is returned by :func:`get_params`.
model:
The neural model.
batch:
It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
Returns:
Return a dict, whose key may be "beam-search".
"""
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
max_len: int = 128,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
add_generation_prompt=False,
chat_template=TEMPLATE,
padding="longest",
max_length=max_len,
truncation=True,
)
)
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask
device = model.llm.device
feature = batch["inputs"]
assert feature.ndim == 3
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"]
messages = [
[
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": ""},
]
] * len(feature)
input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)
generated_ids = model.decode(
feature.to(device),
feature_lens.to(device),
input_ids.to(device, dtype=torch.long),
attention_mask.to(device),
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return {"beam-search": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
The dataloader.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a dict, whose key may be "beam-search".
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
batch=batch,
tokenizer=tokenizer,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_text))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out CERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / f"{params.method}"
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(
params.res_dir
/ f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
)
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
logging.info(f"device: {device}")
speech_encoder_embed = get_encoder_embed(params)
speech_encoder = get_encoder_model(params)
load_model_params(
params.speech_encoder_path_or_name, speech_encoder_embed, "encoder_embed"
)
load_model_params(params.speech_encoder_path_or_name, speech_encoder, "encoder")
speech_encoder_dim = max(_to_int_tuple(params.encoder_dim))
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16 FIX ME
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
model = SPEECH_LLM(
speech_encoder_embed,
speech_encoder,
llm,
encoder_projector,
)
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
# deepspeed converted checkpoint only contains model state_dict
filenames = [
f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
for epoch in range(start, params.epoch + 1)
]
avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False)
# filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
# torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
map_location="cpu",
)
model.load_state_dict(checkpoint, strict=False)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
data_module = AsrDataModule(args)
multi_dataset = MultiDataset(args.manifest_dir)
def remove_long_utt(c: Cut):
# Keep only utterances whose duration is at most 30 seconds
if c.duration > 30.0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding. Duration: {c.duration}"
)
return False
return True
if params.dataset == "aishell":
test_sets_cuts = multi_dataset.aishell_test_cuts()
elif params.dataset == "speechio":
test_sets_cuts = multi_dataset.speechio_test_cuts()
elif params.dataset == "wenetspeech_test_meeting":
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
else:
test_sets_cuts = multi_dataset.test_cuts()
test_sets = test_sets_cuts.keys()
test_dls = [
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
for cuts_name in test_sets
]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
tokenizer=tokenizer,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

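The preprocess() helper in decode.py above pads tokenized prompts to the longest sequence in the batch, on the left when flash-attention generation is enabled (tokenizer.padding_side = "left"). A small self-contained sketch of that padding step, with PAD_ID standing in for tokenizer.pad_token_id and toy token ids:

import torch

PAD_ID = 0                                  # stand-in for tokenizer.pad_token_id
texts = [[11, 12, 13, 14], [21, 22]]        # toy tokenized prompts
max_len_texts = max(len(t) for t in texts)
texts = [[PAD_ID] * (max_len_texts - len(t)) + t for t in texts]  # left padding
input_ids = torch.tensor(texts, dtype=torch.int)
attention_mask = input_ids.ne(PAD_ID)       # padded positions become False
print(input_ids.tolist())                   # [[11, 12, 13, 14], [0, 0, 21, 22]]
print(attention_mask.tolist())              # [[True, True, True, True], [False, False, True, True]]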
View File

@ -0,0 +1 @@
../whisper_llm_zh/ds_config_zero1.json

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/encoder_interface.py

View File

@ -0,0 +1,379 @@
from typing import Tuple
import torch
from encoder_interface import EncoderInterface
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother
from icefall.utils import make_pad_mask
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class EncoderProjector(nn.Module):
"""
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
Args:
encoder_dim (:obj:`int`): The dimension of the encoder outputs.
llm_dim (:obj:`int`): The dimension of the language model.
downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
"""
def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
super().__init__()
self.downsample_rate = downsample_rate
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(llm_dim, llm_dim)
def forward(self, x):
batch_size, seq_len, feat_dim = x.size()
num_padding_frames = (
self.downsample_rate - seq_len % self.downsample_rate
) % self.downsample_rate
if num_padding_frames > 0:
x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
seq_len = x.size(1)
x = x.contiguous()
x = x.view(
batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
class SPEECH_LLM(nn.Module):
"""
The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
The encoder is used to extract speech features from the input speech signal.
The encoder projector is used to project the encoder outputs to the same dimension as the language model.
The language model is used to generate the text from the speech features.
Args:
encoder (:obj:`nn.Module`): The encoder module.
llm (:obj:`nn.Module`): The language model module.
encoder_projector (:obj:`nn.Module`): The encoder projector module.
"""
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
llm: nn.Module,
encoder_projector: nn.Module,
):
super().__init__()
self.encoder_embed = encoder_embed
self.encoder = encoder
self.llm = llm
self.encoder_projector = encoder_projector
def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
):
"""
Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
with the speech features and padding the input_ids to the maximum length of the speech features.
Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
Args:
speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
input_ids (:obj:`torch.Tensor`): The input ids to merge.
attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
Returns:
:obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
"""
num_speechs, speech_len, embed_dim = speech_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(
input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
)
# 1. Create a mask to know where special speech tokens are
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (
num_special_speech_tokens.max() * (speech_len - 1)
) + sequence_length
batch_indices, non_speech_indices = torch.where(
input_ids != self.llm.config.default_speech_token_id
)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged speech-text sequence.
# `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
# `torch.cumsum` computes how each speech token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = (
torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
)
nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_speech_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size,
max_embed_dim,
embed_dim,
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)
final_attention_mask = torch.zeros(
batch_size,
max_embed_dim,
dtype=attention_mask.dtype,
device=inputs_embeds.device,
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim),
IGNORE_TOKEN_ID,
dtype=input_ids.dtype,
device=input_ids.device,
)
# In case the speech encoder or the language model has been offloaded to CPU, we need to manually
# move the corresponding tensors to their correct target device.
target_device = inputs_embeds.device
batch_indices, non_speech_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_speech_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
batch_indices, non_speech_indices
]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
batch_indices, non_speech_indices
]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[
batch_indices, non_speech_indices
]
# 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
speech_to_overwrite = torch.full(
(batch_size, max_embed_dim),
True,
dtype=torch.bool,
device=inputs_embeds.device,
)
speech_to_overwrite[batch_indices, text_to_overwrite] = False
speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
:, None
].to(target_device)
if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
)
final_embedding[speech_to_overwrite] = (
speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
)
final_attention_mask |= speech_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
(final_attention_mask == 0), 1
)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(
input_ids == self.llm.config.pad_token_id
)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None:
final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens
def ctc_compress(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
blank_id: int = 0,
) -> torch.Tensor:
"""
Remove frames from encoder_out where CTC argmax predicts blank.
Args:
encoder_out: Tensor of shape (N, T, C), encoder output.
encoder_out_lens: Tensor of shape (N,), lengths before padding.
blank_id: CTC blank token ID (default: 0).
Returns:
Compressed CTC output of shape (N, T', C).
"""
# 1. Compute CTC argmax predictions
ctc_output = self.ctc_output(encoder_out)
ctc_preds = ctc_output.argmax(dim=-1)
# 2. Create non-blank, non-pad mask
padding_mask = make_pad_mask(encoder_out_lens)
non_blank_mask = (ctc_preds != blank_id) & (~padding_mask)
# 3. Compute lengths after compress
compressed_lens = non_blank_mask.sum(dim=1)
max_len = compressed_lens.max().item()
# 4. Pre-pad output
pad_lens_list = (
torch.full_like(
compressed_lens,
max_len,
device=ctc_output.device,
)
- compressed_lens
)
max_pad_len = int(pad_lens_list.max())
padded_ctc_output = torch.nn.functional.pad(ctc_output, [0, 0, 0, max_pad_len])
# 5. Create final mask
padding_mask = ~make_pad_mask(pad_lens_list)
total_mask = torch.concat([non_blank_mask, padding_mask], dim=1)
# 6. Apply mask and reshape
compressed_output = padded_ctc_output[total_mask].reshape(
ctc_output.shape[0], -1, ctc_output.shape[2]
)
return compressed_output
def forward(
self,
fbank: torch.Tensor,
fbank_lens: torch.Tensor,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
labels: torch.LongTensor,
):
encoder_outs, encoder_out_lens = self.forward_encoder(fbank, fbank_lens)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
return model_outputs, acc
def decode(
self,
fbank: torch.Tensor,
fbank_lens: torch.Tensor,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
**kwargs,
):
encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
_,
position_ids,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 200),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", False),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 1.0),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id,
)
return generated_ids
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
return numerator.float() / denominator.float()

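A minimal shape check of the frame-stacking downsample done by EncoderProjector above: the time axis is right-padded to a multiple of the downsample rate, then rate consecutive frames are folded into the feature dimension before the two linear layers (toy sizes, torch only):

import torch

N, T, C, rate = 2, 13, 4, 5                 # batch, frames, encoder dim, downsample rate
x = torch.randn(N, T, C)
pad = (rate - T % rate) % rate              # right-pad so T becomes divisible by rate
x = torch.nn.functional.pad(x, (0, 0, 0, pad))
stacked = x.view(N, x.size(1) // rate, C * rate)
print(stacked.shape)                        # torch.Size([2, 3, 20]); linear1 then maps 20 -> llm_dim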
View File

@ -0,0 +1 @@
../whisper_llm_zh/multi_dataset.py

View File

@ -0,0 +1,5 @@
librosa
deepspeed>=0.16.9
transformers>=4.37.0
flash-attn
peft

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

File diff suppressed because it is too large.

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py