removing debug log

This commit is contained in:
root 2024-06-11 09:17:31 +00:00 committed by Yuekai Zhang
parent 271536248f
commit 4ebccebcc0

View File

@ -134,7 +134,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=4,
default=1,
help="Downsample rate for the encoder projector.",
)
@ -290,11 +290,6 @@ def decode_one_batch(
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": ""},
]] * len(feature)
# messages = [[
# {"role": "system", "content": "你是一个能处理音频的助手。"},
# {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
# {"role": "assistant", "content": ""},
# ]] * len(feature)
input_ids, attention_mask = preprocess(
messages, tokenizer, max_len=128
@ -302,13 +297,7 @@ def decode_one_batch(
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
print(hyps)
texts = batch["supervisions"]["text"]
for i, text in enumerate(texts):
print(f"ref: {text}")
print(f"hyp: {hyps[i]}")
return {"beam-search": hyps}