enable torch.where

This commit is contained in:
a002 2025-03-25 14:37:43 +08:00
parent 90a91243a2
commit 4b8f90712f

View File

@ -1493,7 +1493,7 @@ class SwooshROnnx(torch.nn.Module):
def SwooshLForward(x: Tensor):
x_offset = x - 4.0
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
# log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
return log_sum - 0.08 * x - 0.035
@ -1502,7 +1502,7 @@ def SwooshLForward(x: Tensor):
def SwooshRForward(x: Tensor):
x_offset = x - 1.0
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
# log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
return log_sum - 0.08 * x - 0.313261687