mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix typo and reformatted zipformer
This commit is contained in:
parent
a98f6b27a4
commit
7eb957e993
@ -17,7 +17,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scaling import Balancer
|
||||
|
||||
|
||||
|
||||
@ -160,6 +160,7 @@ from typing import List, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from torch import Tensor, nn
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
@ -170,7 +171,6 @@ from icefall.checkpoint import (
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import make_pad_mask, str2bool
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
||||
@ -43,13 +43,9 @@ from pathlib import Path
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
)
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
||||
@ -19,9 +19,9 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
from scaling import ScaledLinear
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
|
||||
@ -120,10 +120,11 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from icefall.utils import make_pad_mask
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
||||
@ -22,24 +22,24 @@ Usage: ./zipformer/profile.py
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from torch import Tensor, nn
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
from icefall.profiler import get_model_profile
|
||||
from scaling import BiasNorm
|
||||
from torch import Tensor, nn
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
get_encoder_embed,
|
||||
get_encoder_model,
|
||||
get_joiner_model,
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
)
|
||||
from zipformer import BypassModule
|
||||
|
||||
from icefall.profiler import get_model_profile
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
||||
@ -15,15 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
import logging
|
||||
import k2
|
||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
||||
import random
|
||||
import torch
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
class PiecewiseLinear(object):
|
||||
|
||||
@ -16,11 +16,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from scaling import (
|
||||
Balancer,
|
||||
BiasNorm,
|
||||
@ -34,6 +33,7 @@ from scaling import (
|
||||
SwooshR,
|
||||
Whiten,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class ConvNeXt(nn.Module):
|
||||
|
||||
@ -17,28 +17,33 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import random
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import (
|
||||
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
|
||||
)
|
||||
from scaling import (
|
||||
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
||||
)
|
||||
from scaling import (
|
||||
ActivationDropoutAndLinear,
|
||||
Balancer,
|
||||
BiasNorm,
|
||||
Dropout2,
|
||||
ChunkCausalDepthwiseConv1d,
|
||||
ActivationDropoutAndLinear,
|
||||
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
||||
Dropout2,
|
||||
FloatLike,
|
||||
ScheduledFloat,
|
||||
Whiten,
|
||||
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
|
||||
convert_num_channels,
|
||||
limit_param_value,
|
||||
penalize_abs_values_gt,
|
||||
softmax,
|
||||
ScheduledFloat,
|
||||
FloatLike,
|
||||
limit_param_value,
|
||||
convert_num_channels,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user