mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add decode file
This commit is contained in:
parent
b5a906cbbd
commit
3dbbc29429
27
egs/speech_llm/ASR_LLM/debug.sh
Executable file
27
egs/speech_llm/ASR_LLM/debug.sh
Executable file
@ -0,0 +1,27 @@
|
||||
|
||||
export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
|
||||
# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
# pip install -r whisper/requirements.txt
|
||||
export CUDA_VISIBLE_DEVICES=0,1
|
||||
# torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
|
||||
# --max-duration 80 \
|
||||
# --exp-dir ./whisper_llm_zh/exp_test \
|
||||
# --speech-encoder-path-or-name tiny \
|
||||
# --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
|
||||
# --manifest-dir data/fbank \
|
||||
# --deepspeed \
|
||||
# --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
|
||||
# --use-flash-attn False
|
||||
|
||||
|
||||
|
||||
python3 ./whisper_llm_zh/decode.py \
|
||||
--max-duration 80 \
|
||||
--exp-dir ./whisper_llm_zh/exp_test \
|
||||
--speech-encoder-path-or-name tiny \
|
||||
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
|
||||
--epoch 1 --avg 1 \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
|
||||
--use-flash-attn False
|
@ -1,14 +1,16 @@
|
||||
|
||||
export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
|
||||
# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
# pip install -r whisper/requirements.txt
|
||||
export CUDA_VISIBLE_DEVICES=0,1
|
||||
torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
|
||||
--max-duration 80 \
|
||||
--exp-dir ./whisper_llm_zh/exp_test \
|
||||
--speech-encoder-path-or-name tiny \
|
||||
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/asr/icefall
|
||||
#pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
#pip install -r whisper_llm_zh/requirements.txt
|
||||
#export CUDA_VISIBLE_DEVICES=0,1
|
||||
|
||||
whisper_path=/workspace/asr/icefall_asr_multi-hans-zh_whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
|
||||
llm_path=/workspace/asr/Qwen1.5-7B-Chat
|
||||
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
|
||||
--max-duration 100 \
|
||||
--exp-dir ./whisper_llm_zh/exp_qwen_7b \
|
||||
--speech-encoder-path-or-name $whisper_path \
|
||||
--llm-path-or-name $llm_path \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
|
||||
--use-flash-attn False
|
||||
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json
|
@ -30,15 +30,6 @@ python3 ./whisper/decode.py \
|
||||
--epoch 999 --avg 1 \
|
||||
--beam-size 10 --max-duration 50
|
||||
|
||||
# Command for decoding using pretrained models (before fine-tuning):
|
||||
|
||||
python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch -1 --avg 1 \
|
||||
--remove-whisper-encoder-input-length-restriction False \
|
||||
--beam-size 10 --max-duration 50
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -70,7 +61,7 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from train import DEFAULT_SPEECH_TOKEN
|
||||
|
||||
def average_checkpoints(
|
||||
filenames: List[Path], device: torch.device = torch.device("cpu")
|
||||
@ -123,48 +114,27 @@ def average_checkpoints(
|
||||
|
||||
return avg
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--llm-path-or-name",
|
||||
type=str,
|
||||
default="/workspace/asr/Qwen1.5-0.5B-Chat",
|
||||
help="Path or name of the large language model.",
|
||||
)
|
||||
|
||||
def remove_punctuation(text: str or List[str]):
|
||||
"""Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
|
||||
|
||||
Args:
|
||||
text: It can be a string or a list of strings.
|
||||
Returns:
|
||||
Return a string or a list of strings without any punctuation.
|
||||
"""
|
||||
punctuation = "!,.;:?、!,。;:?《》 "
|
||||
if isinstance(text, str):
|
||||
text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
|
||||
return text
|
||||
elif isinstance(text, list):
|
||||
result_text = []
|
||||
for t in text:
|
||||
t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
|
||||
result_text.append(t)
|
||||
return result_text
|
||||
else:
|
||||
raise Exception(f"Not support type {type(text)}")
|
||||
|
||||
|
||||
def to_simple(text: str or List[str]):
|
||||
"""Convert traditional Chinese to simplified Chinese.
|
||||
Args:
|
||||
text: It can be a string or a list of strings.
|
||||
Returns:
|
||||
Return a string or a list of strings converted to simplified Chinese.
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
text = convert(text, "zh-cn")
|
||||
return text
|
||||
elif isinstance(text, list):
|
||||
result_text = []
|
||||
for t in text:
|
||||
t = convert(t, "zh-cn")
|
||||
result_text.append(t)
|
||||
return result_text
|
||||
else:
|
||||
raise Exception(f"Not support type{type(text)}")
|
||||
parser.add_argument(
|
||||
"--speech-encoder-path-or-name",
|
||||
type=str,
|
||||
default="whisper-large-v2",
|
||||
help="Path or name of the speech encoder.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-projector-ds-rate",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Downsample rate for the encoder projector.",
|
||||
)
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -211,15 +181,6 @@ def get_parser():
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="large-v2",
|
||||
choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
|
||||
help="""The model name to use.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--remove-whisper-encoder-input-length-restriction",
|
||||
type=str2bool,
|
||||
@ -227,13 +188,7 @@ def get_parser():
|
||||
help="replace whisper encoder forward method to remove input length restriction",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-distill-whisper",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use architecture of distill whisper.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
@ -249,6 +204,7 @@ def get_params() -> AttributeDict:
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
tokenizer: AutoTokenizer,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
@ -266,8 +222,33 @@ def decode_one_batch(
|
||||
Returns:
|
||||
Return a dict, whose key may be "beam-search".
|
||||
"""
|
||||
def preprocess(
|
||||
messages,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_len: int = 128,
|
||||
) -> Dict:
|
||||
"""Preprocesses the data for supervised fine-tuning."""
|
||||
texts = []
|
||||
for i, msg in enumerate(messages):
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
padding="max_length",
|
||||
max_length=max_len,
|
||||
truncation=True,
|
||||
)
|
||||
)
|
||||
|
||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
|
||||
return input_ids, attention_mask
|
||||
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
device = model.device
|
||||
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
@ -288,12 +269,25 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_len = supervisions["num_frames"]
|
||||
feature_len = feature_len.to(device, dtype=dtype)
|
||||
results = model.decode(feature, params.decoding_options)
|
||||
hyps = [result.text for result in results]
|
||||
|
||||
hyps = remove_punctuation(hyps)
|
||||
hyps = to_simple(hyps)
|
||||
hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
|
||||
messages = []
|
||||
for i, text in enumerate(texts):
|
||||
message = [
|
||||
{"role": "system", "content": "你是一个能处理音频的助手。"},
|
||||
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
messages.append(message)
|
||||
input_ids, attention_mask = preprocess(
|
||||
messages, tokenizer, max_len=128
|
||||
)
|
||||
|
||||
model_outputs = model.decode(feature, input_ids.to(device, dtype=torch.LongTensor), attention_mask.to(device))
|
||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
# hyps = remove_punctuation(hyps)
|
||||
# hyps = to_simple(hyps)
|
||||
# hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
|
||||
print(hyps)
|
||||
return {"beam-search": hyps}
|
||||
|
||||
@ -302,6 +296,7 @@ def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
tokenizer: AutoTokenizer,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -370,6 +365,7 @@ def decode_dataset(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
@ -455,16 +451,6 @@ def main():
|
||||
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
|
||||
)
|
||||
|
||||
options = whisper.DecodingOptions(
|
||||
task="transcribe",
|
||||
language="zh",
|
||||
without_timestamps=True,
|
||||
beam_size=params.beam_size,
|
||||
)
|
||||
params.decoding_options = options
|
||||
params.cleaner = BasicTextNormalizer()
|
||||
params.normalizer = Normalizer()
|
||||
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
@ -476,49 +462,68 @@ def main():
|
||||
|
||||
if params.remove_whisper_encoder_input_length_restriction:
|
||||
replace_whisper_encoder_forward()
|
||||
if params.use_distill_whisper:
|
||||
replace_whisper_decoder_forward()
|
||||
model = whisper.load_model(params.model_name, "cpu")
|
||||
if params.epoch > 0:
|
||||
if params.avg > 1:
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
# deepspeed converted checkpoint only contains model state_dict
|
||||
filenames = [
|
||||
f"{params.exp_dir}/epoch-{epoch}.pt"
|
||||
for epoch in range(start, params.epoch + 1)
|
||||
]
|
||||
model.load_state_dict(average_checkpoints(filenames))
|
||||
else:
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
# save checkpoints
|
||||
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save(model.state_dict(), filename)
|
||||
|
||||
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
|
||||
speech_encoder = whisper_model.encoder
|
||||
speech_encoder_dim = whisper_model.dims.n_audio_state
|
||||
|
||||
if params.use_flash_attn:
|
||||
attn_implementation = "flash_attention_2"
|
||||
torch_dtype=torch.bfloat16
|
||||
|
||||
else:
|
||||
attn_implementation = "eager"
|
||||
torch_dtype=torch.float16
|
||||
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
params.llm_path_or_name,
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
|
||||
tokenizer.padding_side = 'left'
|
||||
special_tokens_dict = {
|
||||
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
|
||||
}
|
||||
tokenizer.add_special_tokens(special_tokens_dict)
|
||||
llm.config.pad_token_id = tokenizer.pad_token_id
|
||||
llm.config.bos_token_id = tokenizer.bos_token_id
|
||||
llm.config.eos_token_id = tokenizer.eos_token_id
|
||||
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
|
||||
|
||||
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
|
||||
|
||||
model = SPEECH_LLM(
|
||||
speech_encoder,
|
||||
llm,
|
||||
encoder_projector,
|
||||
)
|
||||
|
||||
|
||||
if params.avg > 1:
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
)
|
||||
assert "model" not in checkpoint
|
||||
# deepspeed converted checkpoint only contains model state_dict
|
||||
filenames = [
|
||||
f"{params.exp_dir}/epoch-{epoch}.pt"
|
||||
for epoch in range(start, params.epoch + 1)
|
||||
]
|
||||
model.load_state_dict(average_checkpoints(filenames), strict=False)
|
||||
|
||||
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save(model.state_dict(), filename)
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
else:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
@ -534,13 +539,14 @@ def main():
|
||||
# Keep only utterances with duration in 30 seconds
|
||||
#
|
||||
if c.duration > 30.0:
|
||||
# logging.warning(
|
||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
# )
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
test_sets_cuts = multi_dataset.test_cuts()
|
||||
# test_sets_cuts = multi_dataset.test_cuts()
|
||||
test_sets_cuts = multi_dataset.aishell_test_cuts()
|
||||
|
||||
test_sets = test_sets_cuts.keys()
|
||||
test_dls = [
|
||||
@ -553,6 +559,7 @@ def main():
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
|
@ -38,7 +38,7 @@ class SPEECH_LLM(nn.Module):
|
||||
self.encoder_projector = encoder_projector
|
||||
self.encoder_outputs_downsample_rate = 4
|
||||
|
||||
def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
|
||||
num_speechs, speech_len, embed_dim = speech_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
|
||||
@ -132,18 +132,45 @@ class SPEECH_LLM(nn.Module):
|
||||
speech_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
|
||||
# outputs = self.llm(
|
||||
# attention_mask=attention_mask,
|
||||
# position_ids=position_ids,
|
||||
# past_key_values=past_key_values,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# use_cache=use_cache,
|
||||
# output_attentions=output_attentions,
|
||||
# output_hidden_states=output_hidden_states,
|
||||
# return_dict=return_dict,
|
||||
# )
|
||||
# logits = outputs[0]
|
||||
|
||||
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
|
||||
|
||||
return model_outputs
|
||||
|
||||
|
||||
def decode(self,
|
||||
fbank: torch.Tensor = None,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
**kwargs:
|
||||
):
|
||||
|
||||
encoder_outs = self.encoder(fbank)
|
||||
# downsample encoder_outs by 4
|
||||
encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
|
||||
speech_features = self.encoder_projector(encoder_outs)
|
||||
|
||||
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||
inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
|
||||
speech_features, inputs_embeds, input_ids, attention_mask
|
||||
)
|
||||
|
||||
model_outputs = self.llm.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
max_new_tokens=kwargs.get("max_new_tokens", 200),
|
||||
num_beams=kwargs.get("num_beams", 1),
|
||||
do_sample=kwargs.get("do_sample", False),
|
||||
min_length=kwargs.get("min_length", 1),
|
||||
top_p=kwargs.get("top_p", 1.0),
|
||||
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
|
||||
length_penalty=kwargs.get("length_penalty", 1.0),
|
||||
temperature=kwargs.get("temperature", 1.0),
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
bos_token_id=self.llm.config.bos_token_id,
|
||||
eos_token_id=self.llm.config.eos_token_id,
|
||||
pad_token_id=self.tllm.config.pad_token_id
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
|
||||
]
|
||||
return generated_ids
|
@ -267,4 +267,15 @@ class MultiDataset:
|
||||
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
return aishell_dev_cuts
|
||||
return aishell_dev_cuts
|
||||
|
||||
def aishell_test_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset test cuts")
|
||||
|
||||
# AISHELL
|
||||
logging.info("Loading Aishell set in lazy mode")
|
||||
aishell_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell_cuts_test.jsonl.gz"
|
||||
)
|
||||
|
||||
return aishell_test_cuts
|
@ -725,13 +725,16 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.use_flash_attn:
|
||||
attn_implementation = "flash_attention_2"
|
||||
torch_dtype=torch.bfloat16
|
||||
|
||||
else:
|
||||
attn_implementation = "eager"
|
||||
torch_dtype=torch.float16
|
||||
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
params.llm_path_or_name,
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
|
||||
tokenizer.padding_side = 'left'
|
||||
|
Loading…
x
Reference in New Issue
Block a user