Fix chunk issue for sherpa

Author: Erwan
Date:   2023-10-18 08:23:40 +02:00
parent  d2bd0933b1
commit  e070a737c0


@@ -17,28 +17,33 @@
 # limitations under the License.
 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
 import torch
-import random
 from encoder_interface import EncoderInterface
 from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn
@@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
         # s will go through tanh.
@@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
         # s will go through tanh.
         s = self.tanh(s)
@@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module):
         x = self.in_proj(x)  # (time, batch, 2*channels)
-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.balancer1(s)
         s = self.sigmoid(s)
         x = self.activation1(x)  # identity.
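
Note on the change: the tensors split here are 3-D (seq_len/time, batch, channels), so chunking over dim=-1 and over dim=2 addresses the same axis and the edit does not change numerical behaviour in PyTorch; the positive dimension index is presumably what the sherpa export path requires, per the commit title. A minimal sketch checking that equivalence follows; the tensor sizes are illustrative only, not taken from the model configuration.

import torch

# Hypothetical sizes for illustration: (seq_len, batch, 3 * hidden_channels),
# the layout seen by NonlinAttention before the three-way chunk.
x = torch.randn(16, 2, 3 * 128)

old_split = x.chunk(3, dim=-1)  # behaviour before this commit
new_split = x.chunk(3, dim=2)   # behaviour after this commit

# On a 3-D tensor, dim=-1 and dim=2 are the same axis, so the pieces match exactly.
assert all(torch.equal(a, b) for a, b in zip(old_split, new_split))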