add decode file

Yuekai Zhang 2024-06-04 22:56:57 -07:00
parent b5a906cbbd
commit 3dbbc29429
6 changed files with 230 additions and 153 deletions

egs/speech_llm/ASR_LLM/debug.sh (new executable file, 27 lines added)
View File

@@ -0,0 +1,27 @@
export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
# pip install -r whisper/requirements.txt
export CUDA_VISIBLE_DEVICES=0,1
# torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
# --max-duration 80 \
# --exp-dir ./whisper_llm_zh/exp_test \
# --speech-encoder-path-or-name tiny \
# --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
# --manifest-dir data/fbank \
# --deepspeed \
# --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
# --use-flash-attn False
python3 ./whisper_llm_zh/decode.py \
--max-duration 80 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name tiny \
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
--epoch 1 --avg 1 \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn False

View File

@@ -1,14 +1,16 @@
export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
# pip install -r whisper/requirements.txt
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
--max-duration 80 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name tiny \
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
export PYTHONPATH=$PYTHONPATH:/workspace/asr/icefall
#pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
#pip install -r whisper_llm_zh/requirements.txt
#export CUDA_VISIBLE_DEVICES=0,1
whisper_path=/workspace/asr/icefall_asr_multi-hans-zh_whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
llm_path=/workspace/asr/Qwen1.5-7B-Chat
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 100 \
--exp-dir ./whisper_llm_zh/exp_qwen_7b \
--speech-encoder-path-or-name $whisper_path \
--llm-path-or-name $llm_path \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn False
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json

View File

@@ -30,15 +30,6 @@ python3 ./whisper/decode.py \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
# Command for decoding using pretrained models (before fine-tuning):
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch -1 --avg 1 \
--remove-whisper-encoder-input-length-restriction False \
--beam-size 10 --max-duration 50
"""
import argparse
@@ -70,7 +61,7 @@ from icefall.utils import (
str2bool,
write_error_stats,
)
from train import DEFAULT_SPEECH_TOKEN
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
@@ -123,48 +114,27 @@ def average_checkpoints(
return avg
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--llm-path-or-name",
type=str,
default="/workspace/asr/Qwen1.5-0.5B-Chat",
help="Path or name of the large language model.",
)
def remove_punctuation(text: str or List[str]):
"""Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
Args:
text: It can be a string or a list of strings.
Returns:
Return a string or a list of strings without any punctuation.
"""
punctuation = "!,.;:?、!,。;:?《》 "
if isinstance(text, str):
text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
result_text.append(t)
return result_text
else:
raise Exception(f"Not support type {type(text)}")
def to_simple(text: str or List[str]):
"""Convert traditional Chinese to simplified Chinese.
Args:
text: It can be a string or a list of strings.
Returns:
Return a string or a list of strings converted to simplified Chinese.
"""
if isinstance(text, str):
text = convert(text, "zh-cn")
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = convert(t, "zh-cn")
result_text.append(t)
return result_text
else:
raise Exception(f"Not support type{type(text)}")
parser.add_argument(
"--speech-encoder-path-or-name",
type=str,
default="whisper-large-v2",
help="Path or name of the speech encoder.",
)
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=4,
help="Downsample rate for the encoder projector.",
)
def get_parser():
parser = argparse.ArgumentParser(
@@ -211,15 +181,6 @@ def get_parser():
help="The experiment dir",
)
parser.add_argument(
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
help="""The model name to use.
""",
)
parser.add_argument(
"--remove-whisper-encoder-input-length-restriction",
type=str2bool,
@@ -227,13 +188,7 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction",
)
parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)
add_model_arguments(parser)
return parser
@@ -249,6 +204,7 @@ def get_params() -> AttributeDict:
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
batch: dict,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
@@ -266,8 +222,33 @@
Returns:
Return a dict, whose key may be "beam-search".
"""
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
max_len: int = 128,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
add_generation_prompt=False,
padding="max_length",
max_length=max_len,
truncation=True,
)
)
input_ids = torch.tensor(texts, dtype=torch.int)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask
dtype = torch.float16
device = torch.device("cuda")
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
@@ -288,12 +269,25 @@
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
results = model.decode(feature, params.decoding_options)
hyps = [result.text for result in results]
hyps = remove_punctuation(hyps)
hyps = to_simple(hyps)
hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
messages = []
for i, text in enumerate(texts):
message = [
{"role": "system", "content": "你是一个能处理音频的助手。"},
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": ""},
]
messages.append(message)
input_ids, attention_mask = preprocess(
messages, tokenizer, max_len=128
)
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# hyps = remove_punctuation(hyps)
# hyps = to_simple(hyps)
# hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
print(hyps)
return {"beam-search": hyps}
@@ -302,6 +296,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@@ -370,6 +365,7 @@
params=params,
model=model,
batch=batch,
tokenizer=tokenizer,
)
for lm_scale, hyps in hyps_dict.items():
@@ -455,16 +451,6 @@ def main():
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
)
options = whisper.DecodingOptions(
task="transcribe",
language="zh",
without_timestamps=True,
beam_size=params.beam_size,
)
params.decoding_options = options
params.cleaner = BasicTextNormalizer()
params.normalizer = Normalizer()
logging.info("Decoding started")
logging.info(params)
@@ -476,49 +462,68 @@
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
if params.avg > 1:
start = params.epoch - params.avg
assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
if "model" not in checkpoint:
# deepspeed converted checkpoint only contains model state_dict
filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt"
for epoch in range(start, params.epoch + 1)
]
model.load_state_dict(average_checkpoints(filenames))
else:
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
# save checkpoints
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(model.state_dict(), filename)
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype=torch.bfloat16
else:
attn_implementation = "eager"
torch_dtype=torch.float16
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
tokenizer.padding_side = 'left'
special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.bos_token_id = tokenizer.bos_token_id
llm.config.eos_token_id = tokenizer.eos_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
)
if params.avg > 1:
start = params.epoch - params.avg
assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict
filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt"
for epoch in range(start, params.epoch + 1)
]
model.load_state_dict(average_checkpoints(filenames), strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(model.state_dict(), filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=False)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True)
else:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
@@ -534,13 +539,14 @@ def main():
# Keep only utterances with duration in 30 seconds
#
if c.duration > 30.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
return True
test_sets_cuts = multi_dataset.test_cuts()
# test_sets_cuts = multi_dataset.test_cuts()
test_sets_cuts = multi_dataset.aishell_test_cuts()
test_sets = test_sets_cuts.keys()
test_dls = [
@@ -553,6 +559,7 @@
dl=test_dl,
params=params,
model=model,
tokenizer=tokenizer,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
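
Note on the new prompt path: decode_one_batch() now builds a fixed chat prompt per utterance, tokenizes it with left padding via preprocess(), and hands the resulting input_ids/attention_mask to SPEECH_LLM.decode() together with the fbank features. Below is a minimal sketch of what that tokenization step produces for one utterance, assuming a locally available Qwen1.5 chat tokenizer and using "<speech>" to stand in for DEFAULT_SPEECH_TOKEN from train.py (both names are illustrative assumptions, not values fixed by this diff):

import torch
from transformers import AutoTokenizer

DEFAULT_SPEECH_TOKEN = "<speech>"  # placeholder string; must match train.DEFAULT_SPEECH_TOKEN

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
tokenizer.padding_side = "left"
tokenizer.add_special_tokens({"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]})

# Same roles as in decode_one_batch(): a system prompt ("You are an assistant that can
# process audio.") and a user request ("Please transcribe the audio into text.").
message = [
    {"role": "system", "content": "你是一个能处理音频的助手。"},
    {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
    {"role": "assistant", "content": ""},
]
ids = tokenizer.apply_chat_template(
    message,
    tokenize=True,
    add_generation_prompt=False,
    padding="max_length",
    max_length=128,
    truncation=True,
)
input_ids = torch.tensor([ids], dtype=torch.long)
attention_mask = input_ids.ne(tokenizer.pad_token_id)

# The speech placeholder survives as a single special-token id; SPEECH_LLM later
# expands that one position into the projected whisper-encoder frames.
speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
print((input_ids == speech_token_id).nonzero(as_tuple=False))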

View File

@@ -38,7 +38,7 @@ class SPEECH_LLM(nn.Module):
self.encoder_projector = encoder_projector
self.encoder_outputs_downsample_rate = 4
def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels):
def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
num_speechs, speech_len, embed_dim = speech_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
@@ -132,18 +132,45 @@
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
# outputs = self.llm(
# attention_mask=attention_mask,
# position_ids=position_ids,
# past_key_values=past_key_values,
# inputs_embeds=inputs_embeds,
# use_cache=use_cache,
# output_attentions=output_attentions,
# output_hidden_states=output_hidden_states,
# return_dict=return_dict,
# )
# logits = outputs[0]
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
return model_outputs
def decode(self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
**kwargs,
):
encoder_outs = self.encoder(fbank)
# downsample encoder_outs by 4
encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 200),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", False),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 1.0),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
position_ids=position_ids,
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id,
)
# generate() was called with inputs_embeds only, so generated_ids already
# contains just the newly generated tokens; no prompt stripping is needed.
return generated_ids
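
For intuition on what llm.generate() actually consumes in decode(): _merge_input_ids_with_speech_features swaps the single speech-placeholder position of each prompt for the whole block of projected encoder frames, so the merged sequence is longer than the text prompt. A toy shape check in pure torch, with illustrative sizes (a 30 s whisper segment yields 1500 encoder frames, 375 after the ::4 downsampling; 1024 is just an assumed hidden size):

import torch

batch, text_len, dim = 2, 128, 1024   # prompt length and LLM hidden size (illustrative)
speech_len = 1500 // 4                # whisper encoder frames after ::4 downsampling

inputs_embeds = torch.randn(batch, text_len, dim)      # embedded chat prompt
speech_features = torch.randn(batch, speech_len, dim)  # projected encoder output

# One placeholder token per utterance is replaced by speech_len feature vectors,
# so the merged sequence grows from text_len to text_len - 1 + speech_len.
merged_len = text_len - 1 + speech_len
print(merged_len)  # 502 positions are fed to llm.generate via inputs_embeds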

View File

@@ -267,4 +267,15 @@ class MultiDataset:
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)
return aishell_dev_cuts
return aishell_dev_cuts
def aishell_test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
# AISHELL
logging.info("Loading Aishell set in lazy mode")
aishell_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_test.jsonl.gz"
)
return aishell_test_cuts
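
A quick way to sanity-check the new split before a full decoding run; the module path, constructor argument, and fbank directory below are assumptions based on the recipe layout (data/fbank as used in debug.sh):

# Sketch: run from the recipe directory; assumes data/fbank/aishell_cuts_test.jsonl.gz exists.
from multi_dataset import MultiDataset

multi_dataset = MultiDataset("data/fbank")
cuts = multi_dataset.aishell_test_cuts()
print(cuts)                        # lazy lhotse CutSet backed by aishell_cuts_test.jsonl.gz
print(next(iter(cuts)).duration)   # peek at the first test cut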

View File

@@ -725,13 +725,16 @@ def run(rank, world_size, args):
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype=torch.bfloat16
else:
attn_implementation = "eager"
torch_dtype=torch.float16
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
tokenizer.padding_side = 'left'
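
Both train.py and decode.py now set tokenizer.padding_side = 'left'. For the decode path in particular this matters because of batched generation: llm.generate() continues every row of the batch from its last position, so the real prompt tokens must sit at the right edge. A tiny pure-torch illustration (token ids are made up):

import torch

PAD = 0
# Two prompts of different lengths padded to a common length of 5.
right_padded = torch.tensor([[11, 12, 13, PAD, PAD],
                             [21, 22, 23, 24, 25]])
left_padded = torch.tensor([[PAD, PAD, 11, 12, 13],
                            [21, 22, 23, 24, 25]])

# generate() appends new tokens after the last column of each row:
print(right_padded[:, -1])  # tensor([ 0, 25]) -> row 0 would continue from padding
print(left_padded[:, -1])   # tensor([13, 25]) -> every row continues from real tokens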