mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
disable where
This commit is contained in:
parent
0e749e4eb0
commit
615e5206e1
@ -1493,7 +1493,7 @@ class SwooshROnnx(torch.nn.Module):
|
|||||||
def SwooshLForward(x: Tensor):
|
def SwooshLForward(x: Tensor):
|
||||||
x_offset = x - 4.0
|
x_offset = x - 4.0
|
||||||
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
||||||
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
# log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
||||||
return log_sum - 0.08 * x - 0.035
|
return log_sum - 0.08 * x - 0.035
|
||||||
|
|
||||||
|
|
||||||
@ -1502,7 +1502,7 @@ def SwooshLForward(x: Tensor):
|
|||||||
def SwooshRForward(x: Tensor):
|
def SwooshRForward(x: Tensor):
|
||||||
x_offset = x - 1.0
|
x_offset = x - 1.0
|
||||||
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
||||||
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
# log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
||||||
return log_sum - 0.08 * x - 0.313261687
|
return log_sum - 0.08 * x - 0.313261687
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user