mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
black
and isort
formatted the project
This commit is contained in:
parent
d5566b8c08
commit
2f4a6e95e6
@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from icefall import is_module_available
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
|
||||
import torch
|
||||
from icefall import is_module_available
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -70,9 +70,9 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from tokenizer import Tokenizer
|
||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -23,6 +23,7 @@ from pathlib import Path
|
||||
|
||||
from lhotse import CutSet, SupervisionSegment
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Similar text filtering and normalization procedure as in:
|
||||
|
@ -76,6 +76,7 @@ from beam_search import (
|
||||
)
|
||||
from gigaspeech_scoring import asr_text_post_processing
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
|
@ -88,7 +88,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import GigaSpeechAsrDataModule
|
||||
from train import add_model_arguments, get_params, get_model
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_model
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -42,12 +42,10 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import GigaSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
keywords_search,
|
||||
)
|
||||
from beam_search import keywords_search
|
||||
from lhotse.cut import Cut
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from lhotse.cut import Cut
|
||||
from icefall import ContextGraph
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -76,6 +76,20 @@ from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
add_training_arguments,
|
||||
compute_loss,
|
||||
compute_validation_loss,
|
||||
display_and_save_batch,
|
||||
get_adjusted_batch_count,
|
||||
get_model,
|
||||
get_params,
|
||||
load_checkpoint_if_available,
|
||||
save_checkpoint,
|
||||
scan_pessimistic_batches_for_oom,
|
||||
set_batch_count,
|
||||
)
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import remove_checkpoints
|
||||
@ -95,21 +109,6 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
)
|
||||
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
add_training_arguments,
|
||||
compute_loss,
|
||||
compute_validation_loss,
|
||||
display_and_save_batch,
|
||||
get_adjusted_batch_count,
|
||||
get_model,
|
||||
get_params,
|
||||
load_checkpoint_if_available,
|
||||
save_checkpoint,
|
||||
scan_pessimistic_batches_for_oom,
|
||||
set_batch_count,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
|
@ -24,8 +24,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from train import get_params, get_ctc_model
|
||||
from train import get_ctc_model, get_params
|
||||
|
||||
|
||||
def test_model():
|
||||
|
@ -59,9 +59,9 @@ import onnx
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||
from emformer import Emformer
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -39,7 +39,7 @@ Usage of this script:
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
@ -47,7 +47,6 @@ import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch.multiprocessing as mp
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AsrDataModule
|
||||
from beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.supervision import AlignmentItem
|
||||
from lhotse.serialization import SequentialJsonlWriter
|
||||
from lhotse.supervision import AlignmentItem
|
||||
|
||||
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_encoder_model, get_params
|
||||
|
||||
from icefall.profiler import get_model_profile
|
||||
from train import get_encoder_model, add_model_arguments, get_params
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from icefall import is_module_available
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
|
||||
import torch
|
||||
from icefall import is_module_available
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -76,8 +76,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AsrDataModule
|
||||
from librispeech import LibriSpeech
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from scaling import BasicNorm, DoubleSwish
|
||||
from torch import Tensor, nn
|
||||
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||
|
||||
from icefall.profiler import get_model_profile
|
||||
from scaling import BasicNorm, DoubleSwish
|
||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -82,8 +82,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
@ -20,7 +20,6 @@ from typing import List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
|
||||
# The force alignment problem can be formulated as finding
|
||||
|
@ -107,9 +107,6 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from gigaspeech import GigaSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
@ -120,6 +117,9 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
|
||||
# from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from gigaspeech import GigaSpeechAsrDataModule
|
||||
from gigaspeech_scoring import asr_text_post_processing
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
@ -65,16 +65,15 @@ from typing import Dict, List
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from scaling import BasicNorm, DoubleSwish
|
||||
from torch import Tensor, nn
|
||||
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||
|
||||
from icefall.profiler import get_model_profile
|
||||
from scaling import BasicNorm, DoubleSwish
|
||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
@ -24,7 +24,6 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
@ -118,8 +118,8 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -18,10 +18,7 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
ScaledConv1d,
|
||||
)
|
||||
from scaling import ActivationBalancer, ScaledConv1d
|
||||
|
||||
|
||||
class LConv(nn.Module):
|
||||
|
@ -52,7 +52,7 @@ import onnxruntime as ort
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
|
||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
|
@ -4,7 +4,6 @@
|
||||
import ncnn
|
||||
import numpy as np
|
||||
|
||||
|
||||
layer_list = []
|
||||
|
||||
|
||||
|
@ -42,7 +42,6 @@ import ncnn
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
from ncnn_custom_layer import RegisterCustomLayers
|
||||
|
||||
|
||||
|
@ -1,10 +1,11 @@
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import pprint
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import pprint
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
@ -88,7 +88,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from train import add_model_arguments, get_params, get_model
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -22,9 +22,9 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
from scaling import ScaledLinear
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
|
||||
|
@ -22,24 +22,24 @@ Usage: ./zipformer/my_profile.py
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from torch import Tensor, nn
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
from icefall.profiler import get_model_profile
|
||||
from scaling import BiasNorm
|
||||
from torch import Tensor, nn
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
get_encoder_embed,
|
||||
get_encoder_model,
|
||||
get_joiner_model,
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
)
|
||||
from zipformer import BypassModule
|
||||
|
||||
from icefall.profiler import get_model_profile
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -77,11 +77,10 @@ from typing import List, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from k2 import SymbolTable
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
from k2 import SymbolTable
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
from typing import Dict
|
||||
import kaldifst
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
from typing import Dict
|
||||
import kaldifst
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
from typing import Dict
|
||||
import kaldifst
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
@ -15,15 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
import logging
|
||||
import k2
|
||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
||||
import random
|
||||
import torch
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
|
||||
|
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_model
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -16,11 +16,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from scaling import (
|
||||
Balancer,
|
||||
BiasNorm,
|
||||
@ -34,6 +33,7 @@ from scaling import (
|
||||
SwooshR,
|
||||
Whiten,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class ConvNeXt(nn.Module):
|
||||
|
@ -858,7 +858,9 @@ def main():
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
import pdb; pdb.set_trace()
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
@ -877,9 +879,13 @@ def main():
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device), strict=False)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False)
|
||||
load_checkpoint(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
|
||||
)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
@ -888,7 +894,9 @@ def main():
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device), strict=False)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
@ -917,7 +925,7 @@ def main():
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
),
|
||||
strict=False
|
||||
strict=False,
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
@ -936,7 +944,7 @@ def main():
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
),
|
||||
strict=False
|
||||
strict=False,
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
|
@ -121,7 +121,7 @@ from beam_search import (
|
||||
modified_beam_search_lm_shallow_fusion,
|
||||
modified_beam_search_LODR,
|
||||
)
|
||||
from train import add_model_arguments, add_finetune_arguments, get_model, get_params
|
||||
from train import add_finetune_arguments, add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
|
@ -72,7 +72,7 @@ import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, add_finetune_arguments, get_model, get_params
|
||||
from train import add_finetune_arguments, add_model_arguments, get_model, get_params
|
||||
from zipformer import Zipformer2
|
||||
|
||||
from icefall.checkpoint import (
|
||||
|
@ -77,11 +77,10 @@ from typing import List, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from k2 import SymbolTable
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
from k2 import SymbolTable
|
||||
|
||||
conversational_filler = [
|
||||
"UH",
|
||||
@ -182,6 +181,7 @@ def get_parser():
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def post_processing(
|
||||
results: List[Tuple[str, List[str], List[str]]],
|
||||
) -> List[Tuple[str, List[str], List[str]]]:
|
||||
@ -192,6 +192,7 @@ def post_processing(
|
||||
new_results.append((key, new_ref, new_hyp))
|
||||
return new_results
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
model: OnnxModel, token_table: SymbolTable, batch: dict
|
||||
) -> List[List[str]]:
|
||||
|
@ -137,14 +137,14 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
|
||||
"--use-adapters",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="If use adapter to finetune the model"
|
||||
help="If use adapter to finetune the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--adapter-dim",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The bottleneck dimension of the adapter"
|
||||
help="The bottleneck dimension of the adapter",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -1273,7 +1273,11 @@ def run(rank, world_size, args):
|
||||
else:
|
||||
p.requires_grad = False
|
||||
|
||||
logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable/num_param * 100))
|
||||
logging.info(
|
||||
"A total of {} trainable parameters ({:.3f}% of the whole model)".format(
|
||||
num_trainable, num_trainable / num_param * 100
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
|
@ -40,13 +40,13 @@ from scaling import (
|
||||
Dropout2,
|
||||
FloatLike,
|
||||
ScheduledFloat,
|
||||
SwooshL,
|
||||
SwooshR,
|
||||
Whiten,
|
||||
convert_num_channels,
|
||||
limit_param_value,
|
||||
penalize_abs_values_gt,
|
||||
softmax,
|
||||
SwooshL,
|
||||
SwooshR,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
@ -601,8 +601,8 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
bypass_skip_rate: FloatLike = ScheduledFloat(
|
||||
(0.0, 0.5), (4000.0, 0.02), default=0
|
||||
),
|
||||
use_adapters: bool=False,
|
||||
adapter_dim: int=16,
|
||||
use_adapters: bool = False,
|
||||
adapter_dim: int = 16,
|
||||
) -> None:
|
||||
super(Zipformer2EncoderLayer, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -2488,8 +2488,8 @@ def _test_zipformer_main(causal: bool = False):
|
||||
class AdapterModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int=384,
|
||||
bottleneck_dim: int=16,
|
||||
embed_dim: int = 384,
|
||||
bottleneck_dim: int = 16,
|
||||
):
|
||||
# The simplest adapter
|
||||
super(AdapterModule, self).__init__()
|
||||
|
@ -5,9 +5,9 @@ This file prints the text field of supervisions from cutset to the console
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse import load_manifest_lazy
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_args():
|
||||
|
@ -5,7 +5,6 @@ This file generates words.txt from the given transcript file.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
|
@ -29,7 +29,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import SwitchBoardAsrDataModule
|
||||
from conformer import Conformer
|
||||
|
||||
from sclite_scoring import asr_text_post_processing
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
|
@ -16,8 +16,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
|
||||
|
@ -45,6 +45,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall import smart_byte_decode
|
||||
|
||||
|
||||
|
@ -19,9 +19,9 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
from scaling import ScaledLinear
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
|
@ -17,10 +17,10 @@
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import lhotse
|
||||
from pathlib import Path
|
||||
|
||||
import lhotse
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
@ -29,6 +29,7 @@ from lhotse import (
|
||||
fix_manifests,
|
||||
validate_recordings_and_supervisions,
|
||||
)
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
|
@ -41,6 +41,7 @@ from prepare_lang import (
|
||||
write_lexicon,
|
||||
write_mapping,
|
||||
)
|
||||
|
||||
from icefall.utils import text_to_pinyin
|
||||
|
||||
|
||||
|
@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp:
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from icefall import is_module_available
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
|
||||
import torch
|
||||
from icefall import is_module_available
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -30,9 +30,7 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import WenetSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
keywords_search,
|
||||
)
|
||||
from beam_search import keywords_search
|
||||
from lhotse.cut import Cut
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
|
@ -87,6 +87,19 @@ from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
add_training_arguments,
|
||||
compute_validation_loss,
|
||||
display_and_save_batch,
|
||||
get_adjusted_batch_count,
|
||||
get_model,
|
||||
get_params,
|
||||
load_checkpoint_if_available,
|
||||
save_checkpoint,
|
||||
scan_pessimistic_batches_for_oom,
|
||||
set_batch_count,
|
||||
)
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
@ -109,21 +122,6 @@ from icefall.utils import (
|
||||
text_to_pinyin,
|
||||
)
|
||||
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
add_training_arguments,
|
||||
compute_validation_loss,
|
||||
display_and_save_batch,
|
||||
get_adjusted_batch_count,
|
||||
get_model,
|
||||
get_params,
|
||||
load_checkpoint_if_available,
|
||||
save_checkpoint,
|
||||
scan_pessimistic_batches_for_oom,
|
||||
set_batch_count,
|
||||
)
|
||||
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
|
@ -99,7 +99,6 @@ from icefall.utils import (
|
||||
text_to_pinyin,
|
||||
)
|
||||
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
|
@ -18,9 +18,8 @@ you can use ./export.py --jit 1
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import List
|
||||
import math
|
||||
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
|
@ -8,7 +8,6 @@
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
|
||||
WHITESPACE_NORMALIZER = re.compile(r"\s+")
|
||||
SPACE = chr(32)
|
||||
SPACE_ESCAPE = chr(9601)
|
||||
|
@ -8,12 +8,12 @@ The lang_dir should contain the following files:
|
||||
"""
|
||||
|
||||
import math
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import kaldifst
|
||||
import re
|
||||
|
||||
|
||||
class Lexicon:
|
||||
|
@ -18,7 +18,7 @@
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, List
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
@ -5,14 +5,15 @@
|
||||
|
||||
# This is modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py
|
||||
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
@ -5,16 +5,16 @@
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from model import RnnLmModel
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from train import get_params
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
from typing import Dict
|
||||
from train import get_params
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
|
@ -28,8 +28,6 @@ from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from pypinyin import pinyin, lazy_pinyin
|
||||
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone, to_finals
|
||||
from shutil import copyfile
|
||||
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
||||
|
||||
@ -40,6 +38,8 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from pypinyin import lazy_pinyin, pinyin
|
||||
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall.checkpoint import average_checkpoints
|
||||
|
Loading…
x
Reference in New Issue
Block a user