Add debug printing of the reference and hypothesis text for each cut in decode_dataset

This commit is contained in:
root 2024-06-11 09:20:59 +00:00 committed by Yuekai Zhang
parent 4ebccebcc0
commit b26d3fa596

View File

@@ -297,7 +297,7 @@ def decode_one_batch(
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)) generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return {"beam-search": hyps} return {"beam-search": hyps}
@@ -383,6 +383,8 @@ def decode_dataset(
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text) ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split() ref_words = ref_text.split()
print(f"ref: {ref_text}")
print(f"hyp: {''.join(hyp_words)}")
this_batch.append((cut_id, ref_words, hyp_words)) this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch) results[lm_scale].extend(this_batch)