Fix chunk issue for sherpa (#1316)

Erwan Zerhouni authored on 2023-10-18 10:07:10 +02:00, committed by GitHub
parent d2bd0933b1
commit 807816fec0

@@ -17,28 +17,33 @@
 # limitations under the License.
 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
 import torch
-import random
 from encoder_interface import EncoderInterface
 from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn
@@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
         # s will go through tanh.
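The chunk fixes in the remaining hunks follow the same pattern as this one: the split stays on the channel axis, but the dimension is written as the explicit positive index 2 instead of -1. For the 3-D activation used here, shaped (seq_len, batch_size, 3 * hidden_channels), the two calls are equivalent; a minimal sketch of that equivalence (tensor name and sizes are illustrative, not taken from the diff):

import torch

seq_len, batch_size, hidden_channels = 5, 2, 4
x = torch.randn(seq_len, batch_size, 3 * hidden_channels)

# For a 3-D tensor, dim=2 and dim=-1 refer to the same (last) axis,
# so both calls yield the same three (seq_len, batch, hidden) pieces.
s1, x1, y1 = x.chunk(3, dim=-1)
s2, x2, y2 = x.chunk(3, dim=2)
assert torch.equal(s1, s2) and torch.equal(x1, x2) and torch.equal(y1, y2)

Judging by the commit title, the explicit positive index is what the sherpa deployment path expects; the underlying export or tracing constraint is not stated in the diff, so treat that motivation as an inference.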
@@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
         # s will go through tanh.
         s = self.tanh(s)
@@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module):
         x = self.in_proj(x)  # (time, batch, 2*channels)
-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.balancer1(s)
         s = self.sigmoid(s)
         x = self.activation1(x)  # identity.
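In ConvolutionModule the same split feeds a sigmoid gate: in_proj doubles the channel dimension, chunk(2, dim=2) separates the signal half from the gate half, and the gated product is what goes on to the convolution. Below is a simplified, self-contained sketch of that gating step, with the balancer and activation left out and all names and sizes chosen for illustration rather than taken from the module's real definition:

import torch
from torch import nn


class GatedSplit(nn.Module):
    """Toy version of the in_proj + chunk + sigmoid gating shown above."""

    def __init__(self, channels: int):
        super().__init__()
        self.in_proj = nn.Linear(channels, 2 * channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.in_proj(x)  # (time, batch, channels) -> (time, batch, 2*channels)
        x, s = x.chunk(2, dim=2)  # explicit positive dim, matching the fix above
        return x * self.sigmoid(s)


t = torch.randn(10, 3, 8)  # (time, batch, channels)
print(GatedSplit(8)(t).shape)  # torch.Size([10, 3, 8])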