Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00
Fix chunk issue for sherpa (#1316)
commit 807816fec0
parent d2bd0933b1
@@ -17,28 +17,33 @@
 # limitations under the License.
 
 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
+
 import torch
-import random
 from encoder_interface import EncoderInterface
 from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn
 
@@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
 
         # s will go through tanh.
 
@@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)
 
         # s will go through tanh.
         s = self.tanh(s)
@@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module):
 
         x = self.in_proj(x)  # (time, batch, 2*channels)
 
-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.balancer1(s)
         s = self.sigmoid(s)
         x = self.activation1(x)  # identity.
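Note on the change: for the 3-D activations used in these modules, with shape (seq_len, batch_size, channels), chunking along dim=2 is numerically identical to chunking along dim=-1; the commit only replaces the negative index with an explicit positive one, presumably because the export paths consumed by sherpa handle explicit dims more reliably. A minimal sketch of the equivalence (the concrete shape below is an illustrative assumption, not part of the commit):

import torch

# Hypothetical stand-in for the input of NonlinAttention.forward:
# (seq_len, batch_size, 3 * hidden_channels).
x = torch.randn(10, 4, 6)

s_old, mid_old, y_old = x.chunk(3, dim=-1)  # before the fix: negative dim
s_new, mid_new, y_new = x.chunk(3, dim=2)   # after the fix: explicit positive dim

# Both spellings split the same (last) axis into three equal parts.
assert torch.equal(s_old, s_new)
assert torch.equal(mid_old, mid_new)
assert torch.equal(y_old, y_new)
print(s_new.shape)  # torch.Size([10, 4, 2])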