from local

This commit is contained in:
dohe0342 2023-04-11 14:15:44 +09:00
parent de2a41d86c
commit 95c3d5b854
2 changed files with 15 additions and 6 deletions

View File

@ -913,12 +913,21 @@ def compute_loss(
if decode:
model.eval()
with torch.no_grad():
hypos = model.module.decode(
x=feature,
x_lens=feature_lens,
y=y,
sp=sp
)
try:
hypos = model.module.decode(
x=feature,
x_lens=feature_lens,
y=y,
sp=sp
)
except:
hypos = model.decode(
x=feature,
x_lens=feature_lens,
y=y,
sp=sp
)
logging.info(f'ref: {batch["supervisions"]["text"][0]}')
logging.info(f'hyp: {" ".join(hypos[0])}')
model.train()