fix requirements

This commit is contained in:
Yuekai Zhang 2024-06-06 16:24:27 +08:00
parent 09ec0d6553
commit 3ac27d5ad4
4 changed files with 10 additions and 8 deletions

View File

@@ -17,8 +17,8 @@ export CUDA_VISIBLE_DEVICES=0,1
python3 ./whisper_llm_zh/decode.py \
--max-duration 80 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name tiny \
--exp-dir ./whisper_llm_zh/exp_qwen_0.5B \
--speech-encoder-path-or-name /mnt/samsung-t7/yuekai/asr/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
--epoch 1 --avg 1 \
--manifest-dir data/fbank \

View File

@@ -291,7 +291,7 @@ def decode_one_batch(
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
# hyps = remove_punctuation(hyps)
# hyps = to_simple(hyps)
# hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
@@ -377,7 +377,7 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
# assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split()

View File

@@ -169,7 +169,8 @@ class SPEECH_LLM(nn.Module):
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
]
# print(generated_ids, input_ids)
# generated_ids = [
# output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
# ]
return generated_ids

View File

@@ -1,7 +1,8 @@
k2
kaldialign
git+https://github.com/lhotse-speech/lhotse
# sentencepiece
sentencepiece
pypinyin
tensorboard
librosa
# git+https://github.com/yuekaizhang/whisper.git