mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix typo and reformatted zipformer
This commit is contained in:
parent
a98f6b27a4
commit
7eb957e993
@ -17,7 +17,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from scaling import Balancer
|
from scaling import Balancer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -160,6 +160,7 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -170,7 +171,6 @@ from icefall.checkpoint import (
|
|||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import make_pad_mask, str2bool
|
from icefall.utils import make_pad_mask, str2bool
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
|||||||
@ -43,13 +43,9 @@ from pathlib import Path
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
|
||||||
average_checkpoints_with_averaged_model,
|
|
||||||
find_checkpoints,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
|||||||
@ -19,9 +19,9 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import add_sos, make_pad_mask
|
from icefall.utils import add_sos, make_pad_mask
|
||||||
from scaling import ScaledLinear
|
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
class Transducer(nn.Module):
|
||||||
|
|||||||
@ -120,10 +120,11 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from icefall.utils import make_pad_mask
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|||||||
@ -22,24 +22,24 @@ Usage: ./zipformer/profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
|
||||||
from icefall.profiler import get_model_profile
|
|
||||||
from scaling import BiasNorm
|
from scaling import BiasNorm
|
||||||
|
from torch import Tensor, nn
|
||||||
from train import (
|
from train import (
|
||||||
|
add_model_arguments,
|
||||||
get_encoder_embed,
|
get_encoder_embed,
|
||||||
get_encoder_model,
|
get_encoder_model,
|
||||||
get_joiner_model,
|
get_joiner_model,
|
||||||
add_model_arguments,
|
|
||||||
get_params,
|
get_params,
|
||||||
)
|
)
|
||||||
from zipformer import BypassModule
|
from zipformer import BypassModule
|
||||||
|
|
||||||
|
from icefall.profiler import get_model_profile
|
||||||
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|||||||
@ -15,15 +15,16 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
import logging
|
import logging
|
||||||
import k2
|
|
||||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
|
||||||
import random
|
|
||||||
import torch
|
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||||
|
|
||||||
|
|
||||||
class PiecewiseLinear(object):
|
class PiecewiseLinear(object):
|
||||||
|
|||||||
@ -16,11 +16,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
|
||||||
from scaling import (
|
from scaling import (
|
||||||
Balancer,
|
Balancer,
|
||||||
BiasNorm,
|
BiasNorm,
|
||||||
@ -34,6 +33,7 @@ from scaling import (
|
|||||||
SwooshR,
|
SwooshR,
|
||||||
Whiten,
|
Whiten,
|
||||||
)
|
)
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
class ConvNeXt(nn.Module):
|
class ConvNeXt(nn.Module):
|
||||||
|
|||||||
@ -17,28 +17,33 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
import logging
|
|
||||||
import torch
|
import torch
|
||||||
import random
|
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
from scaling import (
|
from scaling import (
|
||||||
|
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
|
||||||
|
)
|
||||||
|
from scaling import (
|
||||||
|
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
||||||
|
)
|
||||||
|
from scaling import (
|
||||||
|
ActivationDropoutAndLinear,
|
||||||
Balancer,
|
Balancer,
|
||||||
BiasNorm,
|
BiasNorm,
|
||||||
Dropout2,
|
|
||||||
ChunkCausalDepthwiseConv1d,
|
ChunkCausalDepthwiseConv1d,
|
||||||
ActivationDropoutAndLinear,
|
Dropout2,
|
||||||
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
FloatLike,
|
||||||
|
ScheduledFloat,
|
||||||
Whiten,
|
Whiten,
|
||||||
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
|
convert_num_channels,
|
||||||
|
limit_param_value,
|
||||||
penalize_abs_values_gt,
|
penalize_abs_values_gt,
|
||||||
softmax,
|
softmax,
|
||||||
ScheduledFloat,
|
|
||||||
FloatLike,
|
|
||||||
limit_param_value,
|
|
||||||
convert_num_channels,
|
|
||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user