mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
black
and isort
formatted the project
This commit is contained in:
parent
d5566b8c08
commit
2f4a6e95e6
@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from icefall import is_module_available
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
import torch
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -70,9 +70,9 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from tokenizer import Tokenizer
|
from tokenizer import Tokenizer
|
||||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -23,6 +23,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from lhotse import CutSet, SupervisionSegment
|
from lhotse import CutSet, SupervisionSegment
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
# Similar text filtering and normalization procedure as in:
|
# Similar text filtering and normalization procedure as in:
|
||||||
|
@ -76,6 +76,7 @@ from beam_search import (
|
|||||||
)
|
)
|
||||||
from gigaspeech_scoring import asr_text_post_processing
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
|
@ -88,7 +88,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import GigaSpeechAsrDataModule
|
from asr_datamodule import GigaSpeechAsrDataModule
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -42,12 +42,10 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import GigaSpeechAsrDataModule
|
from asr_datamodule import GigaSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import keywords_search
|
||||||
keywords_search,
|
from lhotse.cut import Cut
|
||||||
)
|
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from lhotse.cut import Cut
|
|
||||||
from icefall import ContextGraph
|
from icefall import ContextGraph
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -76,6 +76,20 @@ from torch import Tensor
|
|||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from train import (
|
||||||
|
add_model_arguments,
|
||||||
|
add_training_arguments,
|
||||||
|
compute_loss,
|
||||||
|
compute_validation_loss,
|
||||||
|
display_and_save_batch,
|
||||||
|
get_adjusted_batch_count,
|
||||||
|
get_model,
|
||||||
|
get_params,
|
||||||
|
load_checkpoint_if_available,
|
||||||
|
save_checkpoint,
|
||||||
|
scan_pessimistic_batches_for_oom,
|
||||||
|
set_batch_count,
|
||||||
|
)
|
||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.checkpoint import remove_checkpoints
|
from icefall.checkpoint import remove_checkpoints
|
||||||
@ -95,21 +109,6 @@ from icefall.utils import (
|
|||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
|
|
||||||
from train import (
|
|
||||||
add_model_arguments,
|
|
||||||
add_training_arguments,
|
|
||||||
compute_loss,
|
|
||||||
compute_validation_loss,
|
|
||||||
display_and_save_batch,
|
|
||||||
get_adjusted_batch_count,
|
|
||||||
get_model,
|
|
||||||
get_params,
|
|
||||||
load_checkpoint_if_available,
|
|
||||||
save_checkpoint,
|
|
||||||
scan_pessimistic_batches_for_oom,
|
|
||||||
set_batch_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,8 +24,7 @@ To run this file, do:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from train import get_ctc_model, get_params
|
||||||
from train import get_params, get_ctc_model
|
|
||||||
|
|
||||||
|
|
||||||
def test_model():
|
def test_model():
|
||||||
|
@ -59,9 +59,9 @@ import onnx
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
|
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||||
from emformer import Emformer
|
from emformer import Emformer
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -39,7 +39,7 @@ Usage of this script:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
@ -47,7 +47,6 @@ import torch
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import torch.multiprocessing as mp
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
|
||||||
from lhotse import CutSet, load_manifest_lazy
|
from lhotse import CutSet, load_manifest_lazy
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.supervision import AlignmentItem
|
|
||||||
from lhotse.serialization import SequentialJsonlWriter
|
from lhotse.serialization import SequentialJsonlWriter
|
||||||
|
from lhotse.supervision import AlignmentItem
|
||||||
|
|
||||||
|
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
from icefall import is_module_available
|
from icefall import is_module_available
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from train import add_model_arguments, get_encoder_model, get_params
|
||||||
|
|
||||||
from icefall.profiler import get_model_profile
|
from icefall.profiler import get_model_profile
|
||||||
from train import get_encoder_model, add_model_arguments, get_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from icefall import is_module_available
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
import torch
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -76,8 +76,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from librispeech import LibriSpeech
|
from librispeech import LibriSpeech
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from scaling import BasicNorm, DoubleSwish
|
||||||
from typing import Tuple
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||||
|
|
||||||
from icefall.profiler import get_model_profile
|
from icefall.profiler import get_model_profile
|
||||||
from scaling import BasicNorm, DoubleSwish
|
|
||||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -82,8 +82,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||||
|
|
||||||
# The force alignment problem can be formulated as finding
|
# The force alignment problem can be formulated as finding
|
||||||
|
@ -107,9 +107,6 @@ import k2
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# from asr_datamodule import LibriSpeechAsrDataModule
|
|
||||||
from gigaspeech import GigaSpeechAsrDataModule
|
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_nbest,
|
fast_beam_search_nbest,
|
||||||
@ -120,6 +117,9 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from gigaspeech import GigaSpeechAsrDataModule
|
||||||
from gigaspeech_scoring import asr_text_post_processing
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
@ -65,16 +65,15 @@ from typing import Dict, List
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from scaling import BasicNorm, DoubleSwish
|
||||||
from typing import Tuple
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||||
|
|
||||||
from icefall.profiler import get_model_profile
|
from icefall.profiler import get_model_profile
|
||||||
from scaling import BasicNorm, DoubleSwish
|
|
||||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -24,7 +24,6 @@ To run this file, do:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
@ -118,8 +118,8 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -18,10 +18,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling import (
|
from scaling import ActivationBalancer, ScaledConv1d
|
||||||
ActivationBalancer,
|
|
||||||
ScaledConv1d,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LConv(nn.Module):
|
class LConv(nn.Module):
|
||||||
|
@ -52,7 +52,7 @@ import onnxruntime as ort
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
|
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
import ncnn
|
import ncnn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
layer_list = []
|
layer_list = []
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,7 +42,6 @@ import ncnn
|
|||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
|
|
||||||
from ncnn_custom_layer import RegisterCustomLayers
|
from ncnn_custom_layer import RegisterCustomLayers
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import pprint
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
import pprint
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
@ -88,7 +88,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -22,9 +22,9 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import add_sos, make_pad_mask
|
from icefall.utils import add_sos, make_pad_mask
|
||||||
from scaling import ScaledLinear
|
|
||||||
|
|
||||||
|
|
||||||
class AsrModel(nn.Module):
|
class AsrModel(nn.Module):
|
||||||
|
@ -22,24 +22,24 @@ Usage: ./zipformer/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
|
||||||
from icefall.profiler import get_model_profile
|
|
||||||
from scaling import BiasNorm
|
from scaling import BiasNorm
|
||||||
|
from torch import Tensor, nn
|
||||||
from train import (
|
from train import (
|
||||||
|
add_model_arguments,
|
||||||
get_encoder_embed,
|
get_encoder_embed,
|
||||||
get_encoder_model,
|
get_encoder_model,
|
||||||
get_joiner_model,
|
get_joiner_model,
|
||||||
add_model_arguments,
|
|
||||||
get_params,
|
get_params,
|
||||||
)
|
)
|
||||||
from zipformer import BypassModule
|
from zipformer import BypassModule
|
||||||
|
|
||||||
|
from icefall.profiler import get_model_profile
|
||||||
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -77,11 +77,10 @@ from typing import List, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from k2 import SymbolTable
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
from k2 import SymbolTable
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
from typing import Dict
|
|
||||||
import kaldifst
|
import kaldifst
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
from typing import Dict
|
|
||||||
import kaldifst
|
import kaldifst
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
from typing import Dict
|
|
||||||
import kaldifst
|
import kaldifst
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
@ -15,15 +15,16 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
import logging
|
import logging
|
||||||
import k2
|
|
||||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
|
||||||
import random
|
|
||||||
import torch
|
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||||
|
|
||||||
|
|
||||||
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
|
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
|
||||||
|
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -16,11 +16,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
|
||||||
from scaling import (
|
from scaling import (
|
||||||
Balancer,
|
Balancer,
|
||||||
BiasNorm,
|
BiasNorm,
|
||||||
@ -34,6 +33,7 @@ from scaling import (
|
|||||||
SwooshR,
|
SwooshR,
|
||||||
Whiten,
|
Whiten,
|
||||||
)
|
)
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
class ConvNeXt(nn.Module):
|
class ConvNeXt(nn.Module):
|
||||||
|
@ -858,7 +858,9 @@ def main():
|
|||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_model(params)
|
model = get_model(params)
|
||||||
import pdb; pdb.set_trace()
|
import pdb
|
||||||
|
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
if not params.use_averaged_model:
|
if not params.use_averaged_model:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
@ -877,9 +879,13 @@ def main():
|
|||||||
)
|
)
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device), strict=False)
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
elif params.avg == 1:
|
elif params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False)
|
load_checkpoint(
|
||||||
|
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
start = params.epoch - params.avg + 1
|
start = params.epoch - params.avg + 1
|
||||||
filenames = []
|
filenames = []
|
||||||
@ -888,7 +894,9 @@ def main():
|
|||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device), strict=False)
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
@ -917,7 +925,7 @@ def main():
|
|||||||
filename_end=filename_end,
|
filename_end=filename_end,
|
||||||
device=device,
|
device=device,
|
||||||
),
|
),
|
||||||
strict=False
|
strict=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert params.avg > 0, params.avg
|
assert params.avg > 0, params.avg
|
||||||
@ -936,7 +944,7 @@ def main():
|
|||||||
filename_end=filename_end,
|
filename_end=filename_end,
|
||||||
device=device,
|
device=device,
|
||||||
),
|
),
|
||||||
strict=False
|
strict=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
@ -121,7 +121,7 @@ from beam_search import (
|
|||||||
modified_beam_search_lm_shallow_fusion,
|
modified_beam_search_lm_shallow_fusion,
|
||||||
modified_beam_search_LODR,
|
modified_beam_search_LODR,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, add_finetune_arguments, get_model, get_params
|
from train import add_finetune_arguments, add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall import ContextGraph, LmScorer, NgramLm
|
from icefall import ContextGraph, LmScorer, NgramLm
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
|
@ -72,7 +72,7 @@ import torch.nn as nn
|
|||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, add_finetune_arguments, get_model, get_params
|
from train import add_finetune_arguments, add_model_arguments, get_model, get_params
|
||||||
from zipformer import Zipformer2
|
from zipformer import Zipformer2
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
|
@ -77,11 +77,10 @@ from typing import List, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from k2 import SymbolTable
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
from k2 import SymbolTable
|
|
||||||
|
|
||||||
conversational_filler = [
|
conversational_filler = [
|
||||||
"UH",
|
"UH",
|
||||||
@ -182,6 +181,7 @@ def get_parser():
|
|||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def post_processing(
|
def post_processing(
|
||||||
results: List[Tuple[str, List[str], List[str]]],
|
results: List[Tuple[str, List[str], List[str]]],
|
||||||
) -> List[Tuple[str, List[str], List[str]]]:
|
) -> List[Tuple[str, List[str], List[str]]]:
|
||||||
@ -192,6 +192,7 @@ def post_processing(
|
|||||||
new_results.append((key, new_ref, new_hyp))
|
new_results.append((key, new_ref, new_hyp))
|
||||||
return new_results
|
return new_results
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
model: OnnxModel, token_table: SymbolTable, batch: dict
|
model: OnnxModel, token_table: SymbolTable, batch: dict
|
||||||
) -> List[List[str]]:
|
) -> List[List[str]]:
|
||||||
|
@ -137,14 +137,14 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
|
|||||||
"--use-adapters",
|
"--use-adapters",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="If use adapter to finetune the model"
|
help="If use adapter to finetune the model",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adapter-dim",
|
"--adapter-dim",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The bottleneck dimension of the adapter"
|
help="The bottleneck dimension of the adapter",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -1273,7 +1273,11 @@ def run(rank, world_size, args):
|
|||||||
else:
|
else:
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
|
||||||
logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable/num_param * 100))
|
logging.info(
|
||||||
|
"A total of {} trainable parameters ({:.3f}% of the whole model)".format(
|
||||||
|
num_trainable, num_trainable / num_param * 100
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
|
@ -40,13 +40,13 @@ from scaling import (
|
|||||||
Dropout2,
|
Dropout2,
|
||||||
FloatLike,
|
FloatLike,
|
||||||
ScheduledFloat,
|
ScheduledFloat,
|
||||||
|
SwooshL,
|
||||||
|
SwooshR,
|
||||||
Whiten,
|
Whiten,
|
||||||
convert_num_channels,
|
convert_num_channels,
|
||||||
limit_param_value,
|
limit_param_value,
|
||||||
penalize_abs_values_gt,
|
penalize_abs_values_gt,
|
||||||
softmax,
|
softmax,
|
||||||
SwooshL,
|
|
||||||
SwooshR,
|
|
||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
@ -601,8 +601,8 @@ class Zipformer2EncoderLayer(nn.Module):
|
|||||||
bypass_skip_rate: FloatLike = ScheduledFloat(
|
bypass_skip_rate: FloatLike = ScheduledFloat(
|
||||||
(0.0, 0.5), (4000.0, 0.02), default=0
|
(0.0, 0.5), (4000.0, 0.02), default=0
|
||||||
),
|
),
|
||||||
use_adapters: bool=False,
|
use_adapters: bool = False,
|
||||||
adapter_dim: int=16,
|
adapter_dim: int = 16,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Zipformer2EncoderLayer, self).__init__()
|
super(Zipformer2EncoderLayer, self).__init__()
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
@ -2488,8 +2488,8 @@ def _test_zipformer_main(causal: bool = False):
|
|||||||
class AdapterModule(nn.Module):
|
class AdapterModule(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embed_dim: int=384,
|
embed_dim: int = 384,
|
||||||
bottleneck_dim: int=16,
|
bottleneck_dim: int = 16,
|
||||||
):
|
):
|
||||||
# The simplest adapter
|
# The simplest adapter
|
||||||
super(AdapterModule, self).__init__()
|
super(AdapterModule, self).__init__()
|
||||||
|
@ -5,9 +5,9 @@ This file prints the text field of supervisions from cutset to the console
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from lhotse import load_manifest_lazy
|
from lhotse import load_manifest_lazy
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
|
@ -5,7 +5,6 @@ This file generates words.txt from the given transcript file.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,7 +29,6 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import SwitchBoardAsrDataModule
|
from asr_datamodule import SwitchBoardAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
|
||||||
from sclite_scoring import asr_text_post_processing
|
from sclite_scoring import asr_text_post_processing
|
||||||
|
|
||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
|
@ -16,8 +16,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,6 +45,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall import smart_byte_decode
|
from icefall import smart_byte_decode
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,9 +19,9 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import add_sos, make_pad_mask
|
from icefall.utils import add_sos, make_pad_mask
|
||||||
from scaling import ScaledLinear
|
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
class Transducer(nn.Module):
|
||||||
|
@ -17,10 +17,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
|
||||||
import lhotse
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import lhotse
|
||||||
|
import torch
|
||||||
from lhotse import (
|
from lhotse import (
|
||||||
CutSet,
|
CutSet,
|
||||||
Fbank,
|
Fbank,
|
||||||
@ -29,6 +29,7 @@ from lhotse import (
|
|||||||
fix_manifests,
|
fix_manifests,
|
||||||
validate_recordings_and_supervisions,
|
validate_recordings_and_supervisions,
|
||||||
)
|
)
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
|
@ -41,6 +41,7 @@ from prepare_lang import (
|
|||||||
write_lexicon,
|
write_lexicon,
|
||||||
write_mapping,
|
write_mapping,
|
||||||
)
|
)
|
||||||
|
|
||||||
from icefall.utils import text_to_pinyin
|
from icefall.utils import text_to_pinyin
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from icefall import is_module_available
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
import torch
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -30,9 +30,7 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import WenetSpeechAsrDataModule
|
from asr_datamodule import WenetSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import keywords_search
|
||||||
keywords_search,
|
|
||||||
)
|
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
|
@ -87,6 +87,19 @@ from torch import Tensor
|
|||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from train import (
|
||||||
|
add_model_arguments,
|
||||||
|
add_training_arguments,
|
||||||
|
compute_validation_loss,
|
||||||
|
display_and_save_batch,
|
||||||
|
get_adjusted_batch_count,
|
||||||
|
get_model,
|
||||||
|
get_params,
|
||||||
|
load_checkpoint_if_available,
|
||||||
|
save_checkpoint,
|
||||||
|
scan_pessimistic_batches_for_oom,
|
||||||
|
set_batch_count,
|
||||||
|
)
|
||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
@ -109,21 +122,6 @@ from icefall.utils import (
|
|||||||
text_to_pinyin,
|
text_to_pinyin,
|
||||||
)
|
)
|
||||||
|
|
||||||
from train import (
|
|
||||||
add_model_arguments,
|
|
||||||
add_training_arguments,
|
|
||||||
compute_validation_loss,
|
|
||||||
display_and_save_batch,
|
|
||||||
get_adjusted_batch_count,
|
|
||||||
get_model,
|
|
||||||
get_params,
|
|
||||||
load_checkpoint_if_available,
|
|
||||||
save_checkpoint,
|
|
||||||
scan_pessimistic_batches_for_oom,
|
|
||||||
set_batch_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,7 +99,6 @@ from icefall.utils import (
|
|||||||
text_to_pinyin,
|
text_to_pinyin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,9 +18,8 @@ you can use ./export.py --jit 1
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
import math
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
import re
|
import re
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
|
||||||
|
|
||||||
WHITESPACE_NORMALIZER = re.compile(r"\s+")
|
WHITESPACE_NORMALIZER = re.compile(r"\s+")
|
||||||
SPACE = chr(32)
|
SPACE = chr(32)
|
||||||
SPACE_ESCAPE = chr(9601)
|
SPACE_ESCAPE = chr(9601)
|
||||||
|
@ -8,12 +8,12 @@ The lang_dir should contain the following files:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import kaldifst
|
import kaldifst
|
||||||
import re
|
|
||||||
|
|
||||||
|
|
||||||
class Lexicon:
|
class Lexicon:
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple, List
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
@ -5,14 +5,15 @@
|
|||||||
|
|
||||||
# This is modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py
|
# This is modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from functools import partial
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from functools import partial
|
|
||||||
from typing import List, Optional
|
|
||||||
from collections import OrderedDict
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
Tensor = torch.Tensor
|
Tensor = torch.Tensor
|
||||||
|
|
||||||
|
@ -5,16 +5,16 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import onnx
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
from model import RnnLmModel
|
from model import RnnLmModel
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
from train import get_params
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, str2bool
|
||||||
from typing import Dict
|
|
||||||
from train import get_params
|
|
||||||
|
|
||||||
|
|
||||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
@ -28,8 +28,6 @@ from contextlib import contextmanager
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pypinyin import pinyin, lazy_pinyin
|
|
||||||
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone, to_finals
|
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
||||||
|
|
||||||
@ -40,6 +38,8 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from pypinyin import lazy_pinyin, pinyin
|
||||||
|
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints
|
from icefall.checkpoint import average_checkpoints
|
||||||
|
Loading…
x
Reference in New Issue
Block a user