from local

This commit is contained in:
dohe0342 2023-01-09 19:53:01 +09:00
parent c21ab480fc
commit f41acc98ef
2 changed files with 1 additions and 0 deletions

View File

@ -256,6 +256,7 @@ class Interformer(nn.Module):
with torch.cuda.amp.autocast(enabled=False):
for inter_output, pt_output in zip(inter_layer_outputs, pt_layer_outputs):
mse_loss = self.mse(inter_output, pt_output)
loss += mse_loss
'''
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),