mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
fix debug
This commit is contained in:
parent
8bbd06112a
commit
68b99f456f
@ -290,11 +290,9 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
|
|
||||||
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
|
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
|
||||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
|
||||||
# hyps = remove_punctuation(hyps)
|
|
||||||
# hyps = to_simple(hyps)
|
|
||||||
# hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
|
|
||||||
print(hyps)
|
print(hyps)
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
@ -381,7 +379,7 @@ def decode_dataset(
|
|||||||
|
|
||||||
for lm_scale, hyps in hyps_dict.items():
|
for lm_scale, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
# assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
ref_text = normalize_text_alimeeting(ref_text)
|
ref_text = normalize_text_alimeeting(ref_text)
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user