Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 17:42:21 +00:00
Fix chunk issue for sherpa (#1316)
This commit is contained in:
parent d2bd0933b1
commit 807816fec0
@@ -17,28 +17,33 @@
 # limitations under the License.
 
 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
+
 import torch
-import random
 from encoder_interface import EncoderInterface
 from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn
 
@@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
 
         # s will go through tanh.
 
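The only functional change in this hunk is the dimension passed to torch.chunk. The input here has shape (seq_len, batch_size, 3 * hidden_channels), so dim=2 and dim=-1 address the same axis and eager-mode behavior is unchanged; the explicit non-negative index is presumably what the sherpa export path needs. A minimal sketch of the equivalence, with arbitrary illustrative sizes (not taken from the PR):

import torch

# Stand-in for the NonlinAttention input after its input projection:
# shape (seq_len, batch_size, 3 * hidden_channels); sizes are arbitrary.
seq_len, batch_size, hidden_channels = 10, 2, 4
x = torch.randn(seq_len, batch_size, 3 * hidden_channels)

# For a 3-D tensor, dim=-1 and dim=2 refer to the same (last) axis,
# so the old and new calls produce identical chunks.
s_old, x_old, y_old = x.chunk(3, dim=-1)  # before this commit
s_new, x_new, y_new = x.chunk(3, dim=2)   # after this commit
assert s_new.shape == (seq_len, batch_size, hidden_channels)
assert torch.equal(s_old, s_new)
assert torch.equal(x_old, x_new)
assert torch.equal(y_old, y_new)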
@@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
 
         # s will go through tanh.
         s = self.tanh(s)
@@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module):
 
         x = self.in_proj(x)  # (time, batch, 2*channels)
 
-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.balancer1(s)
         s = self.sigmoid(s)
         x = self.activation1(x)  # identity.
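For context, the lines around this change show in_proj doubling the channel dimension and the result being split in two, with one half passed through a sigmoid; the sigmoid output presumably gates the other half (the multiplication itself is outside this hunk). A simplified, self-contained sketch of that pattern, not the actual icefall ConvolutionModule and with illustrative names and sizes:

import torch
from torch import Tensor, nn


class GatedInProj(nn.Module):
    """Illustrative reduction of the in_proj + chunk + sigmoid gating step."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.in_proj = nn.Linear(channels, 2 * channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        # x: (time, batch, channels) -> (time, batch, 2*channels)
        x = self.in_proj(x)
        # Split on the channel axis; dim=2 matches the fixed code above.
        x, s = x.chunk(2, dim=2)
        # Sigmoid gate; the real module also applies a balancer and an
        # identity activation around this step.
        return x * self.sigmoid(s)


x = torch.randn(10, 2, 8)       # (time, batch, channels), arbitrary sizes
print(GatedInProj(8)(x).shape)  # torch.Size([10, 2, 8])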