From 4b8f90712f8d19b7847abc936b479c3b6689a111 Mon Sep 17 00:00:00 2001 From: a002 Date: Tue, 25 Mar 2025 14:37:43 +0800 Subject: [PATCH] enable torch.where --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 34bae5c57..d345c2931 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1493,7 +1493,7 @@ class SwooshROnnx(torch.nn.Module): def SwooshLForward(x: Tensor): x_offset = x - 4.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - # log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.035 @@ -1502,7 +1502,7 @@ def SwooshLForward(x: Tensor): def SwooshRForward(x: Tensor): x_offset = x - 1.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - # log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.313261687