mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 00:24:19 +00:00
black formatted
This commit is contained in:
parent
fde8a2ff65
commit
72d947387d
@ -22,10 +22,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import pprint
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
import pprint
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
@ -253,7 +253,6 @@ class CausalSqueezeExcite1d(nn.Module):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
|
||||||
assert len(x.shape) == 3, "Input is not a 3D tensor!"
|
assert len(x.shape) == 3, "Input is not a 3D tensor!"
|
||||||
y = self.exponential_moving_avg(x)
|
y = self.exponential_moving_avg(x)
|
||||||
y = y.permute(0, 2, 1) # make channel last for squeeze op
|
y = y.permute(0, 2, 1) # make channel last for squeeze op
|
||||||
|
Loading…
x
Reference in New Issue
Block a user