fix debug

This commit is contained in:
Yuekai Zhang 2024-06-07 09:53:40 +08:00
parent 8bbd06112a
commit 68b99f456f

View File

@@ -290,11 +290,9 @@ def decode_one_batch(
) )
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)) generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0] hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
# hyps = remove_punctuation(hyps)
# hyps = to_simple(hyps)
# hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
print(hyps) print(hyps)
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
for i, text in enumerate(texts): for i, text in enumerate(texts):
@@ -381,7 +379,7 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items(): for lm_scale, hyps in hyps_dict.items():
this_batch = [] this_batch = []
# assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text) ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split() ref_words = ref_text.split()