add decode file

commit 3dbbc29429 (parent b5a906cbbd)

 egs/speech_llm/ASR_LLM/debug.sh | 27 ++ (new executable file)
egs/speech_llm/ASR_LLM/debug.sh (new executable file)
@@ -0,0 +1,27 @@
+export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
+# pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
+# pip install -r whisper/requirements.txt
+export CUDA_VISIBLE_DEVICES=0,1
+
+# torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
+#   --max-duration 80 \
+#   --exp-dir ./whisper_llm_zh/exp_test \
+#   --speech-encoder-path-or-name tiny \
+#   --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
+#   --manifest-dir data/fbank \
+#   --deepspeed \
+#   --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+#   --use-flash-attn False
+
+python3 ./whisper_llm_zh/decode.py \
+  --max-duration 80 \
+  --exp-dir ./whisper_llm_zh/exp_test \
+  --speech-encoder-path-or-name tiny \
+  --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
+  --epoch 1 --avg 1 \
+  --manifest-dir data/fbank \
+  --deepspeed \
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --use-flash-attn False
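Note: the two `*-path-or-name` flags in the decode command above are consumed further down in this commit's decode.py changes (see the main() hunk below): the Whisper name or path goes to `whisper.load_model`, and the Qwen name or path to `AutoModelForCausalLM.from_pretrained` / `AutoTokenizer.from_pretrained`. A minimal standalone sketch of that loading, assuming the openai-whisper and transformers packages are installed (illustrative only, not the script itself):

    import torch
    import whisper
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # --speech-encoder-path-or-name tiny
    whisper_model = whisper.load_model("tiny", "cpu")
    speech_encoder = whisper_model.encoder

    # --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat; eager attention since --use-flash-attn False
    llm = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-0.5B-Chat", torch_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")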
@@ -1,14 +1,16 @@
-export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
+export PYTHONPATH=$PYTHONPATH:/workspace/asr/icefall
 #pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
-# pip install -r whisper/requirements.txt
-export CUDA_VISIBLE_DEVICES=0,1
-torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
-  --max-duration 80 \
-  --exp-dir ./whisper_llm_zh/exp_test \
-  --speech-encoder-path-or-name tiny \
-  --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
+#pip install -r whisper_llm_zh/requirements.txt
+#export CUDA_VISIBLE_DEVICES=0,1
+
+whisper_path=/workspace/asr/icefall_asr_multi-hans-zh_whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+llm_path=/workspace/asr/Qwen1.5-7B-Chat
+torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
+  --max-duration 100 \
+  --exp-dir ./whisper_llm_zh/exp_qwen_7b \
+  --speech-encoder-path-or-name $whisper_path \
+  --llm-path-or-name $llm_path \
   --manifest-dir data/fbank \
   --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
-  --use-flash-attn False
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json
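The training command now runs on 8 processes with a larger per-GPU batch budget. Assuming `--max-duration` keeps its usual lhotse/icefall meaning (maximum seconds of audio per batch, per GPU), the rough per-step audio budget across the node can be sketched as:

    # illustrative arithmetic only; names mirror the flags above
    nproc_per_node = 8
    max_duration = 100  # seconds of audio per GPU per batch
    print(f"~{nproc_per_node * max_duration} s of audio per optimizer step")  # ~800 s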
@@ -30,15 +30,6 @@ python3 ./whisper/decode.py \
   --epoch 999 --avg 1 \
   --beam-size 10 --max-duration 50
 
-# Command for decoding using pretrained models (before fine-tuning):
-
-python3 ./whisper/decode.py \
-  --exp-dir whisper/exp_large_v2 \
-  --model-name large-v2 \
-  --epoch -1 --avg 1 \
-  --remove-whisper-encoder-input-length-restriction False \
-  --beam-size 10 --max-duration 50
-
 """
 
 import argparse
@@ -70,7 +61,7 @@ from icefall.utils import (
     str2bool,
     write_error_stats,
 )
+from train import DEFAULT_SPEECH_TOKEN
 
 
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
@@ -123,48 +114,27 @@ def average_checkpoints(
 
     return avg
 
-
-def remove_punctuation(text: str or List[str]):
-    """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
-
-    Args:
-        text: It can be a string or a list of strings.
-    Returns:
-        Return a string or a list of strings without any punctuation.
-    """
-    punctuation = "!,.;:?、!,。;:?《》 "
-    if isinstance(text, str):
-        text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
-        return text
-    elif isinstance(text, list):
-        result_text = []
-        for t in text:
-            t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
-            result_text.append(t)
-        return result_text
-    else:
-        raise Exception(f"Not support type {type(text)}")
-
-
-def to_simple(text: str or List[str]):
-    """Convert traditional Chinese to simplified Chinese.
-    Args:
-        text: It can be a string or a list of strings.
-    Returns:
-        Return a string or a list of strings converted to simplified Chinese.
-    """
-    if isinstance(text, str):
-        text = convert(text, "zh-cn")
-        return text
-    elif isinstance(text, list):
-        result_text = []
-        for t in text:
-            t = convert(t, "zh-cn")
-            result_text.append(t)
-        return result_text
-    else:
-        raise Exception(f"Not support type{type(text)}")
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--llm-path-or-name",
+        type=str,
+        default="/workspace/asr/Qwen1.5-0.5B-Chat",
+        help="Path or name of the large language model.",
+    )
+
+    parser.add_argument(
+        "--speech-encoder-path-or-name",
+        type=str,
+        default="whisper-large-v2",
+        help="Path or name of the speech encoder.",
+    )
+
+    parser.add_argument(
+        "--encoder-projector-ds-rate",
+        type=int,
+        default=4,
+        help="Downsample rate for the encoder projector.",
+    )
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -211,15 +181,6 @@ def get_parser():
         help="The experiment dir",
     )
 
-    parser.add_argument(
-        "--model-name",
-        type=str,
-        default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
-        help="""The model name to use.
-        """,
-    )
-
     parser.add_argument(
         "--remove-whisper-encoder-input-length-restriction",
         type=str2bool,
@@ -227,13 +188,7 @@ def get_parser():
         help="replace whisper encoder forward method to remove input length restriction",
     )
 
-    parser.add_argument(
-        "--use-distill-whisper",
-        type=str2bool,
-        default=False,
-        help="Whether to use architecture of distill whisper.",
-    )
+    add_model_arguments(parser)
 
     return parser
 
 
@@ -249,6 +204,7 @@ def get_params() -> AttributeDict:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
+    tokenizer: AutoTokenizer,
     batch: dict,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -266,8 +222,33 @@ def decode_one_batch(
     Returns:
       Return a dict, whose key may be "beam-search".
     """
+
+    def preprocess(
+        messages,
+        tokenizer: transformers.PreTrainedTokenizer,
+        max_len: int = 128,
+    ) -> Dict:
+        """Preprocesses the data for supervised fine-tuning."""
+        texts = []
+        for i, msg in enumerate(messages):
+            texts.append(
+                tokenizer.apply_chat_template(
+                    msg,
+                    tokenize=True,
+                    add_generation_prompt=False,
+                    padding="max_length",
+                    max_length=max_len,
+                    truncation=True,
+                )
+            )
+
+        input_ids = torch.tensor(texts, dtype=torch.int)
+
+        attention_mask = input_ids.ne(tokenizer.pad_token_id)
+
+        return input_ids, attention_mask
+
     dtype = torch.float16
-    device = torch.device("cuda")
+    device = model.device
 
     feature = batch["inputs"]
     assert feature.ndim == 3
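The nested `preprocess` helper above builds fixed-length prompts with the tokenizer's chat template and derives the attention mask from the pad positions. A standalone sketch of the same flow, assuming a Qwen1.5 chat tokenizer; the `<speech>` placeholder string here is illustrative, the real one is `DEFAULT_SPEECH_TOKEN` imported from train.py and registered as a special token in main():

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
    tokenizer.padding_side = "left"
    speech_token = "<speech>"  # stand-in for DEFAULT_SPEECH_TOKEN
    tokenizer.add_special_tokens({"additional_special_tokens": [speech_token]})

    message = [
        {"role": "system", "content": "你是一个能处理音频的助手。"},
        {"role": "user", "content": f"请转写音频为文字 {speech_token}"},
        {"role": "assistant", "content": ""},
    ]
    ids = tokenizer.apply_chat_template(
        message,
        tokenize=True,
        add_generation_prompt=False,
        padding="max_length",
        max_length=128,
        truncation=True,
    )
    input_ids = torch.tensor([ids], dtype=torch.int)
    attention_mask = input_ids.ne(tokenizer.pad_token_id)  # False at left-pad positions
    print(input_ids.shape, int(attention_mask.sum()))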
@@ -288,12 +269,25 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_len = supervisions["num_frames"]
     feature_len = feature_len.to(device, dtype=dtype)
-    results = model.decode(feature, params.decoding_options)
-    hyps = [result.text for result in results]
 
-    hyps = remove_punctuation(hyps)
-    hyps = to_simple(hyps)
-    hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
+    messages = []
+    for i, text in enumerate(supervisions["text"]):
+        message = [
+            {"role": "system", "content": "你是一个能处理音频的助手。"},
+            {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
+            {"role": "assistant", "content": ""},
+        ]
+        messages.append(message)
+
+    input_ids, attention_mask = preprocess(
+        messages, tokenizer, max_len=128
+    )
+
+    model_outputs = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
+    hyps = tokenizer.batch_decode(model_outputs, skip_special_tokens=True)
+
+    # hyps = remove_punctuation(hyps)
+    # hyps = to_simple(hyps)
+    # hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
     print(hyps)
     return {"beam-search": hyps}
 
@@ -302,6 +296,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
+    tokenizer: AutoTokenizer,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -370,6 +365,7 @@ def decode_dataset(
             params=params,
             model=model,
             batch=batch,
+            tokenizer=tokenizer,
         )
 
         for lm_scale, hyps in hyps_dict.items():
@@ -455,16 +451,6 @@ def main():
         f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
     )
 
-    options = whisper.DecodingOptions(
-        task="transcribe",
-        language="zh",
-        without_timestamps=True,
-        beam_size=params.beam_size,
-    )
-    params.decoding_options = options
-    params.cleaner = BasicTextNormalizer()
-    params.normalizer = Normalizer()
-
     logging.info("Decoding started")
     logging.info(params)
 
@@ -476,39 +462,58 @@ def main()
 
     if params.remove_whisper_encoder_input_length_restriction:
         replace_whisper_encoder_forward()
-    if params.use_distill_whisper:
-        replace_whisper_decoder_forward()
-    model = whisper.load_model(params.model_name, "cpu")
-    if params.epoch > 0:
+
+    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
+    speech_encoder = whisper_model.encoder
+    speech_encoder_dim = whisper_model.dims.n_audio_state
+
+    if params.use_flash_attn:
+        attn_implementation = "flash_attention_2"
+        torch_dtype = torch.bfloat16
+    else:
+        attn_implementation = "eager"
+        torch_dtype = torch.float16
+
+    llm = AutoModelForCausalLM.from_pretrained(
+        params.llm_path_or_name,
+        attn_implementation=attn_implementation,
+        torch_dtype=torch_dtype,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+    tokenizer.padding_side = 'left'
+    special_tokens_dict = {
+        "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
+    }
+    tokenizer.add_special_tokens(special_tokens_dict)
+    llm.config.pad_token_id = tokenizer.pad_token_id
+    llm.config.bos_token_id = tokenizer.bos_token_id
+    llm.config.eos_token_id = tokenizer.eos_token_id
+    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+
+    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
+
+    model = SPEECH_LLM(
+        speech_encoder,
+        llm,
+        encoder_projector,
+    )
+
     if params.avg > 1:
         start = params.epoch - params.avg
         assert start >= 1, start
         checkpoint = torch.load(
             f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
         )
-        if "model" not in checkpoint:
+        assert "model" not in checkpoint
         # deepspeed converted checkpoint only contains model state_dict
         filenames = [
             f"{params.exp_dir}/epoch-{epoch}.pt"
             for epoch in range(start, params.epoch + 1)
         ]
-            model.load_state_dict(average_checkpoints(filenames))
-        else:
-            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-            logging.info(
-                f"Calculating the averaged model over epoch range from "
-                f"{start} (excluded) to {params.epoch}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
-        # save checkpoints
+        model.load_state_dict(average_checkpoints(filenames), strict=False)
         filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
         torch.save(model.state_dict(), filename)
     else:
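`average_checkpoints(filenames)` here averages the DeepSpeed-converted `epoch-*.pt` state dicts, and the result is loaded with `strict=False`, presumably because the saved checkpoints do not cover every parameter of the combined SPEECH_LLM. A rough sketch of element-wise state-dict averaging, under the assumption that each file is a bare state_dict of tensors (this mirrors, but is not, the `average_checkpoints` defined earlier in this file):

    from pathlib import Path
    from typing import List

    import torch


    def average_state_dicts(filenames: List[Path]) -> dict:
        """Element-wise average of several state_dict checkpoints."""
        avg = torch.load(filenames[0], map_location="cpu")
        for f in filenames[1:]:
            state = torch.load(f, map_location="cpu")
            for k in avg:
                avg[k] += state[k]
        for k in avg:
            if avg[k].is_floating_point():
                avg[k] /= len(filenames)
        return avg


    # e.g. --epoch 3 --avg 2 would cover epoch-2.pt and epoch-3.pt
    filenames = [Path(f"exp/epoch-{e}.pt") for e in range(2, 4)]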
@@ -516,7 +521,7 @@ def main():
             f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
         )
         if "model" not in checkpoint:
-            model.load_state_dict(checkpoint, strict=True)
+            model.load_state_dict(checkpoint, strict=False)
         else:
             load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     model.to(device)
@@ -534,13 +539,14 @@ def main():
         # Keep only utterances with duration in 30 seconds
         #
         if c.duration > 30.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
         return True
 
-    test_sets_cuts = multi_dataset.test_cuts()
+    # test_sets_cuts = multi_dataset.test_cuts()
+    test_sets_cuts = multi_dataset.aishell_test_cuts()
 
     test_sets = test_sets_cuts.keys()
     test_dls = [
@@ -553,6 +559,7 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
+            tokenizer=tokenizer,
         )
 
         save_results(params=params, test_set_name=test_set, results_dict=results_dict)
@@ -38,7 +38,7 @@ class SPEECH_LLM(nn.Module):
         self.encoder_projector = encoder_projector
         self.encoder_outputs_downsample_rate = 4
 
-    def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels):
+    def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
         num_speechs, speech_len, embed_dim = speech_features.shape
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
@@ -132,18 +132,45 @@ class SPEECH_LLM(nn.Module):
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
 
-        # outputs = self.llm(
-        #     attention_mask=attention_mask,
-        #     position_ids=position_ids,
-        #     past_key_values=past_key_values,
-        #     inputs_embeds=inputs_embeds,
-        #     use_cache=use_cache,
-        #     output_attentions=output_attentions,
-        #     output_hidden_states=output_hidden_states,
-        #     return_dict=return_dict,
-        # )
-        # logits = outputs[0]
-
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
 
         return model_outputs
 
+    def decode(self,
+               fbank: torch.Tensor = None,
+               input_ids: torch.LongTensor = None,
+               attention_mask: torch.Tensor = None,
+               **kwargs,
+               ):
+
+        encoder_outs = self.encoder(fbank)
+        # downsample encoder_outs by 4
+        encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
+        speech_features = self.encoder_projector(encoder_outs)
+
+        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+        inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
+            speech_features, inputs_embeds, input_ids, attention_mask
+        )
+
+        model_outputs = self.llm.generate(
+            inputs_embeds=inputs_embeds,
+            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            num_beams=kwargs.get("num_beams", 1),
+            do_sample=kwargs.get("do_sample", False),
+            min_length=kwargs.get("min_length", 1),
+            top_p=kwargs.get("top_p", 1.0),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+            length_penalty=kwargs.get("length_penalty", 1.0),
+            temperature=kwargs.get("temperature", 1.0),
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            bos_token_id=self.llm.config.bos_token_id,
+            eos_token_id=self.llm.config.eos_token_id,
+            pad_token_id=self.llm.config.pad_token_id,
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, model_outputs)
+        ]
+        return generated_ids
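The new `decode` method reuses `_merge_input_ids_with_speech_features`, which splices the projected (and 4x downsampled) encoder frames into the prompt embeddings at the position of the speech placeholder token before calling `generate`. A toy sketch of that splicing idea for a single utterance, ignoring batching, labels, and position-id bookkeeping (simplified illustration, not the class's actual method):

    import torch


    def splice_speech(speech_feats, token_embeds, input_ids, speech_token_id):
        """Replace the single speech placeholder token by the speech frames.

        speech_feats: (T_speech, D), token_embeds: (T_text, D), input_ids: (T_text,)
        """
        pos = torch.nonzero(input_ids == speech_token_id, as_tuple=True)[0].item()
        return torch.cat(
            [token_embeds[:pos], speech_feats, token_embeds[pos + 1:]], dim=0
        )


    embeds = torch.randn(6, 4)               # 6 prompt tokens, dim 4
    speech = torch.randn(37, 4)              # 37 downsampled encoder frames
    ids = torch.tensor([1, 2, 3, 99, 4, 5])  # 99 = placeholder id (made up)
    print(splice_speech(speech, embeds, ids, 99).shape)  # torch.Size([42, 4])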
@@ -268,3 +268,14 @@ class MultiDataset:
         )
 
         return aishell_dev_cuts
+
+    def aishell_test_cuts(self) -> CutSet:
+        logging.info("About to get multidataset test cuts")
+
+        # AISHELL
+        logging.info("Loading Aishell set in lazy mode")
+        aishell_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
+        )
+
+        return aishell_test_cuts
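Decoding is now restricted to the AISHELL-1 test set through this new helper (main() above switches from `multi_dataset.test_cuts()` to `multi_dataset.aishell_test_cuts()`). A quick way to peek at the manifest it loads, assuming the default `--manifest-dir data/fbank` layout used in debug.sh:

    from lhotse import load_manifest_lazy

    cuts = load_manifest_lazy("data/fbank/aishell_cuts_test.jsonl.gz")
    for i, cut in enumerate(cuts):
        print(cut.id, round(cut.duration, 2), cut.supervisions[0].text)
        if i == 2:  # just the first few cuts
            break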
@@ -725,13 +725,16 @@ def run(rank, world_size, args):
 
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
+        torch_dtype = torch.bfloat16
     else:
         attn_implementation = "eager"
+        torch_dtype = torch.float16
 
     llm = AutoModelForCausalLM.from_pretrained(
         params.llm_path_or_name,
         attn_implementation=attn_implementation,
+        torch_dtype=torch_dtype,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     tokenizer.padding_side = 'left'
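The paired `torch_dtype` is needed because flash_attention_2 in transformers only supports half-precision weights; the commit picks bfloat16 on the flash-attention path and float16 on the eager fallback. If bfloat16 availability is a concern, a hedged pre-check could look like this (illustrative, not part of the commit):

    import torch

    use_flash_attn = True  # mirrors params.use_flash_attn
    if use_flash_attn and torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
        torch_dtype = torch.float16  # pre-Ampere GPUs: no bfloat16
    elif use_flash_attn:
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float16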