disable where

This commit is contained in:
Fangjun Kuang 2025-03-03 17:40:27 +08:00
parent 0e749e4eb0
commit 615e5206e1

View File

@ -1493,7 +1493,7 @@ class SwooshROnnx(torch.nn.Module):
def SwooshLForward(x: Tensor):
x_offset = x - 4.0
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
# log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
return log_sum - 0.08 * x - 0.035
@ -1502,7 +1502,7 @@ def SwooshLForward(x: Tensor):
def SwooshRForward(x: Tensor):
x_offset = x - 1.0
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
# log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
return log_sum - 0.08 * x - 0.313261687