mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
fix requirements
This commit is contained in:
parent
09ec0d6553
commit
3ac27d5ad4
@ -17,8 +17,8 @@ export CUDA_VISIBLE_DEVICES=0,1
|
|||||||
|
|
||||||
python3 ./whisper_llm_zh/decode.py \
|
python3 ./whisper_llm_zh/decode.py \
|
||||||
--max-duration 80 \
|
--max-duration 80 \
|
||||||
--exp-dir ./whisper_llm_zh/exp_test \
|
--exp-dir ./whisper_llm_zh/exp_qwen_0.5B \
|
||||||
--speech-encoder-path-or-name tiny \
|
--speech-encoder-path-or-name /mnt/samsung-t7/yuekai/asr/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||||
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
|
--llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
|
||||||
--epoch 1 --avg 1 \
|
--epoch 1 --avg 1 \
|
||||||
--manifest-dir data/fbank \
|
--manifest-dir data/fbank \
|
||||||
|
@ -291,7 +291,7 @@ def decode_one_batch(
|
|||||||
|
|
||||||
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
|
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
|
||||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
||||||
# hyps = remove_punctuation(hyps)
|
# hyps = remove_punctuation(hyps)
|
||||||
# hyps = to_simple(hyps)
|
# hyps = to_simple(hyps)
|
||||||
# hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
|
# hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
|
||||||
@ -377,7 +377,7 @@ def decode_dataset(
|
|||||||
|
|
||||||
for lm_scale, hyps in hyps_dict.items():
|
for lm_scale, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
# assert len(hyps) == len(texts)
|
||||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_text = normalize_text_alimeeting(ref_text)
|
ref_text = normalize_text_alimeeting(ref_text)
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
|
@ -169,7 +169,8 @@ class SPEECH_LLM(nn.Module):
|
|||||||
eos_token_id=self.llm.config.eos_token_id,
|
eos_token_id=self.llm.config.eos_token_id,
|
||||||
pad_token_id=self.llm.config.pad_token_id
|
pad_token_id=self.llm.config.pad_token_id
|
||||||
)
|
)
|
||||||
generated_ids = [
|
# print(generated_ids, input_ids)
|
||||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
|
# generated_ids = [
|
||||||
]
|
# output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
|
||||||
|
# ]
|
||||||
return generated_ids
|
return generated_ids
|
@ -1,7 +1,8 @@
|
|||||||
k2
|
k2
|
||||||
kaldialign
|
kaldialign
|
||||||
git+https://github.com/lhotse-speech/lhotse
|
git+https://github.com/lhotse-speech/lhotse
|
||||||
# sentencepiece
|
sentencepiece
|
||||||
|
pypinyin
|
||||||
tensorboard
|
tensorboard
|
||||||
librosa
|
librosa
|
||||||
# git+https://github.com/yuekaizhang/whisper.git
|
# git+https://github.com/yuekaizhang/whisper.git
|
||||||
|
Loading…
x
Reference in New Issue
Block a user