mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
disable where
This commit is contained in:
parent
0e749e4eb0
commit
615e5206e1
@ -1493,7 +1493,7 @@ class SwooshROnnx(torch.nn.Module):
|
||||
def SwooshLForward(x: Tensor):
|
||||
x_offset = x - 4.0
|
||||
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
||||
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
||||
# log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
||||
return log_sum - 0.08 * x - 0.035
|
||||
|
||||
|
||||
@ -1502,7 +1502,7 @@ def SwooshLForward(x: Tensor):
|
||||
def SwooshRForward(x: Tensor):
|
||||
x_offset = x - 1.0
|
||||
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
||||
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
||||
# log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
||||
return log_sum - 0.08 * x - 0.313261687
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user