from local

This commit is contained in:
dohe0342 2023-01-09 19:51:23 +09:00
parent e3aed2e219
commit c21ab480fc
2 changed files with 1 additions and 1 deletions

View File

@ -255,7 +255,7 @@ class Interformer(nn.Module):
with torch.cuda.amp.autocast(enabled=False):
for inter_output, pt_output in zip(inter_layer_outputs, pt_layer_outputs):
mse_loss = self.mse(
mse_loss = self.mse(inter_output, pt_output)
'''
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),