from local

This commit is contained in:
dohe0342 2022-12-10 15:07:10 +09:00
parent c12b23a73b
commit 8695a97ac5
2 changed files with 13 additions and 0 deletions

View File

@ -835,6 +835,19 @@ def compute_loss(
assert loss.requires_grad == is_training
if decode:
model.eval()
with torch.no_grad():
hypos = model.module.decode(
x=feature,
x_lens=feature_lens,
y=y,
sp=sp
)
logging.info(f'ref: {batch["supervisions"]["text"][0]}')
logging.info(f'hyp: {" ".join(hypos[0])}')
model.train()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()