from local

This commit is contained in:
dohe0342 2023-01-09 19:51:14 +09:00
parent 6e789c319f
commit e3aed2e219
3 changed files with 8 additions and 1 deletions

View File

@ -244,13 +244,19 @@ class Interformer(nn.Module):
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
): ):
encoder_out, x_lens, layer_outputs = self.pt_encoder(x, encoder_out, x_lens, pt_layer_outputs = self.pt_encoder(x,
x_lens, x_lens,
warmup=warmup, warmup=warmup,
get_layer_output=True get_layer_output=True
) )
inter_layer_outputs = self.inter_encoder(pt_layer_outputs)
loss = 0
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
for inter_output, pt_output in zip(inter_layer_outputs, pt_layer_outputs):
mse_loss = self.mse(
'''
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(), lm=lm.float(),
am=am.float(), am=am.float(),
@ -263,5 +269,6 @@ class Interformer(nn.Module):
delay_penalty=delay_penalty, delay_penalty=delay_penalty,
return_grad=True, return_grad=True,
) )
'''
return (simple_loss, pruned_loss) return (simple_loss, pruned_loss)