Fix typo and reformatted zipformer

Yifan Yang 2023-06-02 12:42:54 +08:00
parent a98f6b27a4
commit 7eb957e993
9 changed files with 35 additions and 33 deletions
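All nine files get the same treatment: standard-library imports first, then third-party and recipe-local modules, with the icefall imports in a trailing section and the names inside each parenthesized import alphabetized. That ordering matches what an import sorter such as isort produces. The sketch below reproduces the transformation with isort's Python API; it assumes isort is installed and uses default settings, which may differ from the repository's actual configuration.

import isort

# A deliberately mis-ordered snippet in the style of the old files.
before = (
    "from typing import Tuple\n"  # stdlib from-import listed first
    "import warnings\n"           # stdlib straight import
    "import torch\n"              # third-party
    "from torch import Tensor, nn\n"
)

# With default settings, isort places straight imports ahead of
# from-imports within a section and separates the stdlib section
# from third-party imports with a blank line.
print(isort.code(before))
# Expected output:
#   import warnings
#   from typing import Tuple
#
#   import torch
#   from torch import Tensor, nn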

View File

@@ -17,7 +17,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from scaling import Balancer

View File

@@ -160,6 +160,7 @@ from typing import List, Tuple
 import sentencepiece as spm
 import torch
+from scaling_converter import convert_scaled_to_non_scaled
 from torch import Tensor, nn
 from train import add_model_arguments, get_params, get_transducer_model
@@ -170,7 +171,6 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.utils import make_pad_mask, str2bool
-from scaling_converter import convert_scaled_to_non_scaled
 def get_parser():

View File

@@ -43,13 +43,9 @@ from pathlib import Path
 import sentencepiece as spm
 import torch
 from asr_datamodule import LibriSpeechAsrDataModule
 from train import add_model_arguments, get_params, get_transducer_model
-from icefall.checkpoint import (
-    average_checkpoints_with_averaged_model,
-    find_checkpoints,
-)
+from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
 def get_parser():

View File

@@ -19,9 +19,9 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear
 from icefall.utils import add_sos, make_pad_mask
-from scaling import ScaledLinear
 class Transducer(nn.Module):

View File

@@ -120,10 +120,11 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from icefall.utils import make_pad_mask
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
+from icefall.utils import make_pad_mask
 def get_parser():
     parser = argparse.ArgumentParser(

View File

@@ -22,24 +22,24 @@ Usage: ./zipformer/profile.py
 import argparse
 import logging
+from typing import Tuple
 import sentencepiece as spm
 import torch
-from typing import Tuple
-from torch import Tensor, nn
-from icefall.utils import make_pad_mask
-from icefall.profiler import get_model_profile
 from scaling import BiasNorm
+from torch import Tensor, nn
 from train import (
+    add_model_arguments,
     get_encoder_embed,
     get_encoder_model,
     get_joiner_model,
-    add_model_arguments,
     get_params,
 )
 from zipformer import BypassModule
+from icefall.profiler import get_model_profile
+from icefall.utils import make_pad_mask
 def get_parser():
     parser = argparse.ArgumentParser(

View File

@@ -15,15 +15,16 @@
 # limitations under the License.
-from typing import Optional, Tuple, Union
 import logging
-import k2
-from torch.cuda.amp import custom_fwd, custom_bwd
-import random
-import torch
 import math
+import random
+from typing import Optional, Tuple, Union
+import k2
+import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
 class PiecewiseLinear(object):
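Note the last added line: besides moving whole statements, the sorter also alphabetizes the names inside a single from-import, which is why custom_fwd, custom_bwd becomes custom_bwd, custom_fwd. A one-line check under the same assumptions as the sketch above:

import isort

# The imported names are sorted alphabetically within the statement.
print(isort.code("from torch.cuda.amp import custom_fwd, custom_bwd\n"))
# Expected output: from torch.cuda.amp import custom_bwd, custom_fwd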

View File

@@ -16,11 +16,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple
 import warnings
+from typing import Tuple
 import torch
-from torch import Tensor, nn
 from scaling import (
     Balancer,
     BiasNorm,
@@ -34,6 +33,7 @@ from scaling import (
     SwooshR,
     Whiten,
 )
+from torch import Tensor, nn
 class ConvNeXt(nn.Module):

View File

@@ -17,28 +17,33 @@
 # limitations under the License.
 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
 import torch
-import random
 from encoder_interface import EncoderInterface
 from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn
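The two single-name "from scaling import (...)" blocks above keep the long trailing comments attached to Identity and ScaledLinear instead of folding them into the big sorted block, presumably so those explanations stay readable. To verify that a file remains compliant after edits like these, a check along the following lines mirrors what a pre-commit isort hook would run; the path is illustrative and isort must be installed.

import isort

# check_file returns True when the file's imports already satisfy the
# configured order; show_diff=True prints the reordering it would apply
# otherwise.
ok = isort.check_file(
    "egs/librispeech/ASR/zipformer/zipformer.py", show_diff=True
)
print("imports sorted" if ok else "needs reordering")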