add decode log

This commit is contained in:
Yuekai Zhang 2024-06-06 22:01:17 +08:00
parent 412e926941
commit 8bbd06112a

View File

@ -296,6 +296,10 @@ def decode_one_batch(
# hyps = to_simple(hyps) # hyps = to_simple(hyps)
# hyps = [params.normalizer.normalize(hyp) for hyp in hyps] # hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
print(hyps) print(hyps)
texts = batch["supervisions"]["text"]
for i, text in enumerate(texts):
print(f"ref: {text}")
print(f"hyp: {hyps[i]}")
return {"beam-search": hyps} return {"beam-search": hyps}
@ -476,7 +480,8 @@ def main():
if params.use_flash_attn: if params.use_flash_attn:
attn_implementation = "flash_attention_2" attn_implementation = "flash_attention_2"
torch_dtype=torch.bfloat16 # torch_dtype=torch.bfloat16
torch_dtype=torch.float16
else: else:
attn_implementation = "eager" attn_implementation = "eager"