Fix chunk issue for sherpa (#1316)

Erwan Zerhouni authored 2023-10-18 10:07:10 +02:00 · committed by GitHub
parent d2bd0933b1
commit 807816fec0

@@ -17,28 +17,33 @@
 # limitations under the License.
 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
 import torch
-import random
 from encoder_interface import EncoderInterface
 from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn
@@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels

-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)

         # s will go through tanh.
@@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels

-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)

         # s will go through tanh.
         s = self.tanh(s)
@@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module):
         x = self.in_proj(x)  # (time, batch, 2*channels)

-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.balancer1(s)
         s = self.sigmoid(s)
         x = self.activation1(x)  # identity.
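
Beyond the import cleanup, the three functional hunks all make the same change: chunking on dim=-1 is replaced by chunking on dim=2. For the 3-D (seq_len, batch, channels) tensors involved, the two forms select the same axis, so eager-mode behavior is unchanged; the explicit positive index is presumably what sherpa's export path handles more reliably, since negative dims can be treated inconsistently by tracing-based export tooling. A minimal sketch, assuming nothing beyond stock PyTorch (the shapes are illustrative, not taken from the commit), that checks the equivalence:

    import torch

    # Shape as in NonlinAttention.forward(): (seq_len, batch_size, 3 * hidden_channels)
    x = torch.randn(10, 4, 6)

    s_old, x_old, y_old = x.chunk(3, dim=-1)  # old form: negative axis
    s_new, x_new, y_new = x.chunk(3, dim=2)   # new form: explicit last axis of a 3-D tensor

    # For a 3-D tensor, dim=2 and dim=-1 name the same axis, so the splits match exactly.
    assert torch.equal(s_old, s_new)
    assert torch.equal(x_old, x_new)
    assert torch.equal(y_old, y_new)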