just remove unnecessary torch.sum

This commit is contained in:
PingFeng Luo 2022-02-22 14:41:50 +08:00
parent 2332ba312d
commit e558be1cb7

View File

@ -122,4 +122,4 @@ class Transducer(nn.Module):
loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary)
return torch.sum(loss)
return loss