Fix typo and reformat zipformer

Yifan Yang 2023-06-02 12:42:54 +08:00
parent a98f6b27a4
commit 7eb957e993
9 changed files with 35 additions and 33 deletions
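
Every hunk below is the kind of import reshuffle isort produces: plain "import" statements first, "from ... import" statements after them, alphabetized within each group, third-party modules separated from the standard library by a blank line, and the icefall package kept in its own trailing section. A minimal sketch of that ordering, assuming isort with default settings (the commit does not name the formatter that was actually run):

# Sketch only: mirrors the ordering convention seen in the hunks below,
# assuming isort with default settings; the commit does not say which
# tool was run.
import isort

messy = (
    "from typing import Tuple\n"
    "import warnings\n"
    "import torch\n"
)

# isort.code() returns the source with imports regrouped: standard
# library first, third-party after a blank line; within each group,
# plain "import x" lines precede "from x import y" lines.
print(isort.code(messy))
# import warnings
# from typing import Tuple
#
# import torch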


@@ -17,7 +17,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
 from scaling import Balancer


@@ -160,6 +160,7 @@ from typing import List, Tuple
 import sentencepiece as spm
 import torch
+from scaling_converter import convert_scaled_to_non_scaled
 from torch import Tensor, nn
 from train import add_model_arguments, get_params, get_transducer_model
@@ -170,7 +171,6 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.utils import make_pad_mask, str2bool
-from scaling_converter import convert_scaled_to_non_scaled
 
 
 def get_parser():


@@ -43,13 +43,9 @@ from pathlib import Path
 import sentencepiece as spm
 import torch
 from asr_datamodule import LibriSpeechAsrDataModule
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints_with_averaged_model,
-    find_checkpoints,
-)
+from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
 
 
 def get_parser():


@@ -19,9 +19,9 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear
 
 from icefall.utils import add_sos, make_pad_mask
-from scaling import ScaledLinear
 
 
 class Transducer(nn.Module):


@@ -120,10 +120,11 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from icefall.utils import make_pad_mask
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.utils import make_pad_mask
 
 
 def get_parser():
     parser = argparse.ArgumentParser(


@@ -22,24 +22,24 @@ Usage: ./zipformer/profile.py
 import argparse
 import logging
+from typing import Tuple
 
 import sentencepiece as spm
 import torch
-from typing import Tuple
-from torch import Tensor, nn
-from icefall.utils import make_pad_mask
-from icefall.profiler import get_model_profile
 from scaling import BiasNorm
+from torch import Tensor, nn
 from train import (
+    add_model_arguments,
     get_encoder_embed,
     get_encoder_model,
     get_joiner_model,
-    add_model_arguments,
     get_params,
 )
 from zipformer import BypassModule
+
+from icefall.profiler import get_model_profile
+from icefall.utils import make_pad_mask
 
 
 def get_parser():
     parser = argparse.ArgumentParser(


@@ -15,15 +15,16 @@
 # limitations under the License.
-from typing import Optional, Tuple, Union
 import logging
-import k2
-from torch.cuda.amp import custom_fwd, custom_bwd
-import random
-import torch
 import math
+import random
+from typing import Optional, Tuple, Union
+
+import k2
+import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
 
 
 class PiecewiseLinear(object):


@@ -16,11 +16,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple
 import warnings
+from typing import Tuple
 
 import torch
-from torch import Tensor, nn
 from scaling import (
     Balancer,
     BiasNorm,
@@ -34,6 +33,7 @@ from scaling import (
     SwooshR,
     Whiten,
 )
+from torch import Tensor, nn
 
 
 class ConvNeXt(nn.Module):


@@ -17,28 +17,33 @@
 # limitations under the License.
 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
+
 import torch
-import random
 from encoder_interface import EncoderInterface
+from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
 from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn