diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py index 19c518eaf..f04537660 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py @@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging -from icefall import is_module_available +import torch from onnx_pretrained import OnnxModel -import torch +from icefall import is_module_available def get_parser(): diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index b210430c6..06a0fa96b 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -70,9 +70,9 @@ import logging from pathlib import Path import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled from tokenizer import Tokenizer -from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index b6603f80d..a31685211 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -23,6 +23,7 @@ from pathlib import Path from lhotse import CutSet, SupervisionSegment from lhotse.recipes.utils import read_manifests_if_cached + from icefall.utils import str2bool # Similar text filtering and normalization procedure as in: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 72f74c968..ef430302d 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -76,6 +76,7 @@ from beam_search import ( ) from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model + from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py index aa51036d5..651f20cb6 100755 --- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -88,7 +88,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import GigaSpeechAsrDataModule -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py index 7cada8c9d..cb3fd0dc7 100755 --- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py +++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py @@ -51,7 +51,7 @@ from streaming_beam_search import ( ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py index 98b003937..0df2ec356 100755 --- a/egs/gigaspeech/KWS/zipformer/decode.py +++ b/egs/gigaspeech/KWS/zipformer/decode.py @@ -42,12 +42,10 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import GigaSpeechAsrDataModule -from beam_search import ( - keywords_search, -) +from beam_search import keywords_search +from lhotse.cut import Cut from train import add_model_arguments, get_model, get_params -from lhotse.cut import Cut from icefall import ContextGraph from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py index b8e8802cb..2cd7c868b 100755 --- a/egs/gigaspeech/KWS/zipformer/finetune.py +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -76,6 +76,20 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from train import ( + add_model_arguments, + add_training_arguments, + compute_loss, + compute_validation_loss, + display_and_save_batch, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + scan_pessimistic_batches_for_oom, + set_batch_count, +) from icefall import diagnostics from icefall.checkpoint import remove_checkpoints @@ -95,21 +109,6 @@ from icefall.utils import ( str2bool, ) -from train import ( - add_model_arguments, - add_training_arguments, - compute_loss, - compute_validation_loss, - display_and_save_batch, - get_adjusted_batch_count, - get_model, - get_params, - load_checkpoint_if_available, - save_checkpoint, - scan_pessimistic_batches_for_oom, - set_batch_count, -) - LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] diff --git a/egs/librispeech/ASR/conformer_ctc3/test_model.py b/egs/librispeech/ASR/conformer_ctc3/test_model.py index b97b7eed8..aa12d6f83 100755 --- a/egs/librispeech/ASR/conformer_ctc3/test_model.py +++ b/egs/librispeech/ASR/conformer_ctc3/test_model.py @@ -24,8 +24,7 @@ To run this file, do: """ import torch - -from train import get_params, get_ctc_model +from train import get_ctc_model, get_params def test_model(): diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index 1e59e0858..79728afa4 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -59,9 +59,9 @@ import onnx import torch import torch.nn as nn from decoder import Decoder +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from emformer import Emformer from scaling_converter import convert_scaled_to_non_scaled -from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py index 58f587c91..1deecbfc7 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py @@ -39,7 +39,7 @@ Usage of this script: import argparse import logging import math -from typing import List +from typing import List, Optional import kaldifeat import sentencepiece as spm @@ -47,7 +47,6 @@ import torch import torchaudio from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature from torch.nn.utils.rnn import pad_sequence -from typing import Optional, List def get_parser(): diff --git a/egs/librispeech/ASR/long_file_recog/recognize.py b/egs/librispeech/ASR/long_file_recog/recognize.py index 466253446..f4008c23b 100755 --- a/egs/librispeech/ASR/long_file_recog/recognize.py +++ b/egs/librispeech/ASR/long_file_recog/recognize.py @@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat """ import argparse -import torch.multiprocessing as mp -import torch -import torch.nn as nn import logging from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional, Tuple - from pathlib import Path +from typing import List, Optional, Tuple import k2 import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn from asr_datamodule import AsrDataModule from beam_search import ( fast_beam_search_one_best, greedy_search_batch, modified_beam_search, ) -from icefall.utils import AttributeDict, convert_timestamp, setup_logger from lhotse import CutSet, load_manifest_lazy from lhotse.cut import Cut -from lhotse.supervision import AlignmentItem from lhotse.serialization import SequentialJsonlWriter +from lhotse.supervision import AlignmentItem + +from icefall.utils import AttributeDict, convert_timestamp, setup_logger def get_parser(): diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py index c83f38b2a..85e0648d3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py @@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging +import torch from onnx_pretrained import OnnxModel from icefall import is_module_available -import torch - def get_parser(): parser = argparse.ArgumentParser( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py index b844ba613..9762d878c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py @@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py import argparse import logging + import sentencepiece as spm import torch +from train import add_model_arguments, get_encoder_model, get_params from icefall.profiler import get_model_profile -from train import get_encoder_model, add_model_arguments, get_params def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py index 8134d43f8..a235d7b13 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py @@ -75,8 +75,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 5ca4173c1..e2c1d6b5b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging -from icefall import is_module_available +import torch from onnx_pretrained import OnnxModel -import torch +from icefall import is_module_available def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py index 3b1c72cf1..f8fed9519 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py @@ -76,8 +76,7 @@ import torch import torch.nn as nn from asr_datamodule import AsrDataModule from librispeech import LibriSpeech - -from onnx_pretrained import greedy_search, OnnxModel +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py index 4bf773918..cf0598ca3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py @@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py import argparse import logging +from typing import Tuple + import sentencepiece as spm import torch - -from typing import Tuple +from scaling import BasicNorm, DoubleSwish from torch import Tensor, nn +from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params from icefall.profiler import get_model_profile -from scaling import BasicNorm, DoubleSwish -from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py index 6f26e34b5..b0f76317b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py @@ -82,8 +82,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py index bfb5fe609..ee8196c3f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py @@ -20,7 +20,6 @@ from typing import List import k2 import torch - from beam_search import Hypothesis, HypothesisList, get_hyps_shape # The force alignment problem can be formulated as finding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py index b0e4be0d1..7095c3cc8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py @@ -107,9 +107,6 @@ import k2 import sentencepiece as spm import torch import torch.nn as nn - -# from asr_datamodule import LibriSpeechAsrDataModule -from gigaspeech import GigaSpeechAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -120,6 +117,9 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) + +# from asr_datamodule import LibriSpeechAsrDataModule +from gigaspeech import GigaSpeechAsrDataModule from gigaspeech_scoring import asr_text_post_processing from train import add_model_arguments, get_params, get_transducer_model diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py index 37edc0390..3fd14aa47 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py @@ -65,16 +65,15 @@ from typing import Dict, List import sentencepiece as spm import torch - from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) +from icefall.utils import str2bool def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py index 5a068b3b6..1416c6828 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py @@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py import argparse import logging +from typing import Tuple + import sentencepiece as spm import torch - -from typing import Tuple +from scaling import BasicNorm, DoubleSwish from torch import Tensor, nn +from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params from icefall.profiler import get_model_profile -from scaling import BasicNorm, DoubleSwish -from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params def get_parser(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py index 67585ee47..e00281239 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py @@ -75,8 +75,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py index cdf914df3..1f50eb309 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py @@ -24,7 +24,6 @@ To run this file, do: """ import torch - from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py index 01ba7b711..e2f08abc6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py @@ -118,8 +118,8 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import add_model_arguments, get_params, get_transducer_model from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py index a902358ae..2faec7ade 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py @@ -18,10 +18,7 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn -from scaling import ( - ActivationBalancer, - ScaledConv1d, -) +from scaling import ActivationBalancer, ScaledConv1d class LConv(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py index 0ff110370..3a16985bc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -52,7 +52,7 @@ import onnxruntime as ort import sentencepiece as spm import torch import torchaudio -from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence +from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence from icefall.utils import make_pad_mask diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py index 247da0949..07e97bbdb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py @@ -14,6 +14,7 @@ import torch from torch import nn + from icefall.utils import make_pad_mask diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py index 442a0a8af..451c35332 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py @@ -4,7 +4,6 @@ import ncnn import numpy as np - layer_list = [] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py index 999f7e0b4..06127607d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -42,7 +42,6 @@ import ncnn import torch import torchaudio from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature - from ncnn_custom_layer import RegisterCustomLayers diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py index 6c2bf9ea1..cc4471e2b 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py @@ -1,10 +1,11 @@ import argparse import logging import math +import pprint from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple -import pprint + import k2 import sentencepiece as spm import torch diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 4db50b981..1f0f9bfac 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -88,7 +88,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 73009d35c..86da3ab29 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -22,9 +22,9 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask -from scaling import ScaledLinear class AsrModel(nn.Module): diff --git a/egs/librispeech/ASR/zipformer/my_profile.py b/egs/librispeech/ASR/zipformer/my_profile.py index ca20956fb..7e1fd777a 100755 --- a/egs/librispeech/ASR/zipformer/my_profile.py +++ b/egs/librispeech/ASR/zipformer/my_profile.py @@ -22,24 +22,24 @@ Usage: ./zipformer/my_profile.py import argparse import logging +from typing import Tuple + import sentencepiece as spm import torch - -from typing import Tuple -from torch import Tensor, nn - -from icefall.utils import make_pad_mask -from icefall.profiler import get_model_profile from scaling import BiasNorm +from torch import Tensor, nn from train import ( + add_model_arguments, get_encoder_embed, get_encoder_model, get_joiner_model, - add_model_arguments, get_params, ) from zipformer import BypassModule +from icefall.profiler import get_model_profile +from icefall.utils import make_pad_mask + def get_parser(): parser = argparse.ArgumentParser( diff --git a/egs/librispeech/ASR/zipformer/onnx_decode.py b/egs/librispeech/ASR/zipformer/onnx_decode.py index 356c2a830..449294444 100755 --- a/egs/librispeech/ASR/zipformer/onnx_decode.py +++ b/egs/librispeech/ASR/zipformer/onnx_decode.py @@ -77,11 +77,10 @@ from typing import List, Tuple import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from k2 import SymbolTable +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats -from k2 import SymbolTable def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py index a77c3bf2a..114490599 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 import argparse import logging import math -from typing import List, Tuple +from typing import Dict, List, Tuple import k2 import kaldifeat -from typing import Dict import kaldifst import onnxruntime as ort import torch diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py index 6ef944514..f7d3e5253 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 import argparse import logging import math -from typing import List, Tuple +from typing import Dict, List, Tuple import k2 import kaldifeat -from typing import Dict import kaldifst import onnxruntime as ort import torch diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py index ccb3107ea..ebd385364 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 import argparse import logging import math -from typing import List, Tuple +from typing import Dict, List, Tuple import k2 import kaldifeat -from typing import Dict import kaldifst import onnxruntime as ort import torch diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index c0f1e3087..29ac33c02 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -15,15 +15,16 @@ # limitations under the License. -from typing import Optional, Tuple, Union import logging -import k2 -from torch.cuda.amp import custom_fwd, custom_bwd -import random -import torch import math +import random +from typing import Optional, Tuple, Union + +import k2 +import torch import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 8087c1460..360523b8e 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -51,7 +51,7 @@ from streaming_beam_search import ( ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index d16d87bac..b2f769d3f 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -16,11 +16,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import warnings +from typing import Tuple import torch -from torch import Tensor, nn from scaling import ( Balancer, BiasNorm, @@ -34,6 +33,7 @@ from scaling import ( SwooshR, Whiten, ) +from torch import Tensor, nn class ConvNeXt(nn.Module): diff --git a/egs/librispeech/ASR/zipformer_adapter/decode.py b/egs/librispeech/ASR/zipformer_adapter/decode.py index bfa4cc230..91533be8d 100755 --- a/egs/librispeech/ASR/zipformer_adapter/decode.py +++ b/egs/librispeech/ASR/zipformer_adapter/decode.py @@ -858,7 +858,9 @@ def main(): logging.info("About to create model") model = get_model(params) - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() if not params.use_averaged_model: if params.iter > 0: @@ -877,9 +879,13 @@ def main(): ) logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device), strict=False) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False) + load_checkpoint( + f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False + ) else: start = params.epoch - params.avg + 1 filenames = [] @@ -888,7 +894,9 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device), strict=False) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) else: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ @@ -917,7 +925,7 @@ def main(): filename_end=filename_end, device=device, ), - strict=False + strict=False, ) else: assert params.avg > 0, params.avg @@ -936,7 +944,7 @@ def main(): filename_end=filename_end, device=device, ), - strict=False + strict=False, ) model.to(device) diff --git a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py index 903014f4a..bbc582f50 100755 --- a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py +++ b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py @@ -121,7 +121,7 @@ from beam_search import ( modified_beam_search_lm_shallow_fusion, modified_beam_search_LODR, ) -from train import add_model_arguments, add_finetune_arguments, get_model, get_params +from train import add_finetune_arguments, add_model_arguments, get_model, get_params from icefall import ContextGraph, LmScorer, NgramLm from icefall.checkpoint import ( diff --git a/egs/librispeech/ASR/zipformer_adapter/export-onnx.py b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py index a1fc41664..ea29e8159 100755 --- a/egs/librispeech/ASR/zipformer_adapter/export-onnx.py +++ b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py @@ -72,7 +72,7 @@ import torch.nn as nn from decoder import Decoder from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, add_finetune_arguments, get_model, get_params +from train import add_finetune_arguments, add_model_arguments, get_model, get_params from zipformer import Zipformer2 from icefall.checkpoint import ( diff --git a/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py b/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py index 000cea163..e3f7ce85a 100755 --- a/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py +++ b/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py @@ -77,11 +77,10 @@ from typing import List, Tuple import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from k2 import SymbolTable +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats -from k2 import SymbolTable conversational_filler = [ "UH", @@ -182,6 +181,7 @@ def get_parser(): return parser + def post_processing( results: List[Tuple[str, List[str], List[str]]], ) -> List[Tuple[str, List[str], List[str]]]: @@ -192,6 +192,7 @@ def post_processing( new_results.append((key, new_ref, new_hyp)) return new_results + def decode_one_batch( model: OnnxModel, token_table: SymbolTable, batch: dict ) -> List[List[str]]: diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 7f81ddd96..e64c10e7a 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -121,7 +121,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser): default=True, help="If true, finetune from a pre-trained checkpoint", ) - + parser.add_argument( "--use-mux", type=str2bool, @@ -137,14 +137,14 @@ def add_finetune_arguments(parser: argparse.ArgumentParser): "--use-adapters", type=str2bool, default=True, - help="If use adapter to finetune the model" + help="If use adapter to finetune the model", ) parser.add_argument( "--adapter-dim", type=int, default=16, - help="The bottleneck dimension of the adapter" + help="The bottleneck dimension of the adapter", ) parser.add_argument( @@ -1273,7 +1273,11 @@ def run(rank, world_size, args): else: p.requires_grad = False - logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable/num_param * 100)) + logging.info( + "A total of {} trainable parameters ({:.3f}% of the whole model)".format( + num_trainable, num_trainable / num_param * 100 + ) + ) model.to(device) if world_size > 1: diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py index e4e26cd84..4e4695fa5 100644 --- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -40,13 +40,13 @@ from scaling import ( Dropout2, FloatLike, ScheduledFloat, + SwooshL, + SwooshR, Whiten, convert_num_channels, limit_param_value, penalize_abs_values_gt, softmax, - SwooshL, - SwooshR, ) from torch import Tensor, nn @@ -601,8 +601,8 @@ class Zipformer2EncoderLayer(nn.Module): bypass_skip_rate: FloatLike = ScheduledFloat( (0.0, 0.5), (4000.0, 0.02), default=0 ), - use_adapters: bool=False, - adapter_dim: int=16, + use_adapters: bool = False, + adapter_dim: int = 16, ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim @@ -737,7 +737,7 @@ class Zipformer2EncoderLayer(nn.Module): embed_dim=embed_dim, bottleneck_dim=adapter_dim, ) - + # placed after the 2nd convolution module self.post_conv_adapter = AdapterModule( embed_dim=embed_dim, @@ -2488,8 +2488,8 @@ def _test_zipformer_main(causal: bool = False): class AdapterModule(nn.Module): def __init__( self, - embed_dim: int=384, - bottleneck_dim: int=16, + embed_dim: int = 384, + bottleneck_dim: int = 16, ): # The simplest adapter super(AdapterModule, self).__init__() diff --git a/egs/must_c/ST/local/get_text.py b/egs/must_c/ST/local/get_text.py index 558ab6de8..f7b5816a8 100755 --- a/egs/must_c/ST/local/get_text.py +++ b/egs/must_c/ST/local/get_text.py @@ -5,9 +5,9 @@ This file prints the text field of supervisions from cutset to the console """ import argparse +from pathlib import Path from lhotse import load_manifest_lazy -from pathlib import Path def get_args(): diff --git a/egs/must_c/ST/local/get_words.py b/egs/must_c/ST/local/get_words.py index a61f60860..b32925099 100755 --- a/egs/must_c/ST/local/get_words.py +++ b/egs/must_c/ST/local/get_words.py @@ -5,7 +5,6 @@ This file generates words.txt from the given transcript file. """ import argparse - from pathlib import Path diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index 2bbade374..52e501ae1 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -29,7 +29,6 @@ import torch import torch.nn as nn from asr_datamodule import SwitchBoardAsrDataModule from conformer import Conformer - from sclite_scoring import asr_text_post_processing from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler diff --git a/egs/swbd/ASR/local/filter_empty_text.py b/egs/swbd/ASR/local/filter_empty_text.py index 6b3316800..13b35980b 100755 --- a/egs/swbd/ASR/local/filter_empty_text.py +++ b/egs/swbd/ASR/local/filter_empty_text.py @@ -16,8 +16,8 @@ # limitations under the License. import argparse -from pathlib import Path import logging +from pathlib import Path from typing import List diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py index 8c966a2f6..503cdf4ed 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py @@ -45,6 +45,7 @@ import sentencepiece as spm import torch import torchaudio from torch.nn.utils.rnn import pad_sequence + from icefall import smart_byte_decode diff --git a/egs/tedlium3/ASR/zipformer/model.py b/egs/tedlium3/ASR/zipformer/model.py index 90ec7e7aa..65b052ab9 100644 --- a/egs/tedlium3/ASR/zipformer/model.py +++ b/egs/tedlium3/ASR/zipformer/model.py @@ -19,9 +19,9 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask -from scaling import ScaledLinear class Transducer(nn.Module): diff --git a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py index 334a6d023..52da3d6dc 100644 --- a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py +++ b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py @@ -17,10 +17,10 @@ import argparse import logging - -import torch -import lhotse from pathlib import Path + +import lhotse +import torch from lhotse import ( CutSet, Fbank, @@ -29,6 +29,7 @@ from lhotse import ( fix_manifests, validate_recordings_and_supervisions, ) + from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or diff --git a/egs/wenetspeech/ASR/local/prepare_pinyin.py b/egs/wenetspeech/ASR/local/prepare_pinyin.py index ae40f1cdd..112b50b79 100755 --- a/egs/wenetspeech/ASR/local/prepare_pinyin.py +++ b/egs/wenetspeech/ASR/local/prepare_pinyin.py @@ -41,6 +41,7 @@ from prepare_lang import ( write_lexicon, write_mapping, ) + from icefall.utils import text_to_pinyin diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py index ee8252a90..8c192913e 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py @@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging -from icefall import is_module_available +import torch from onnx_pretrained import OnnxModel -import torch +from icefall import is_module_available def get_parser(): diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py index 5ed3c6c2c..340a41231 100755 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -30,9 +30,7 @@ import k2 import torch import torch.nn as nn from asr_datamodule import WenetSpeechAsrDataModule -from beam_search import ( - keywords_search, -) +from beam_search import keywords_search from lhotse.cut import Cut from train import add_model_arguments, get_model, get_params diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index 6f34989e2..76df7e8d5 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -87,6 +87,19 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from train import ( + add_model_arguments, + add_training_arguments, + compute_validation_loss, + display_and_save_batch, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + scan_pessimistic_batches_for_oom, + set_batch_count, +) from icefall import diagnostics from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler @@ -109,21 +122,6 @@ from icefall.utils import ( text_to_pinyin, ) -from train import ( - add_model_arguments, - add_training_arguments, - compute_validation_loss, - display_and_save_batch, - get_adjusted_batch_count, - get_model, - get_params, - load_checkpoint_if_available, - save_checkpoint, - scan_pessimistic_batches_for_oom, - set_batch_count, -) - - LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index 5be34ed99..05acbd6a9 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -99,7 +99,6 @@ from icefall.utils import ( text_to_pinyin, ) - LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py index e29415ffb..6c643c263 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -18,9 +18,8 @@ you can use ./export.py --jit 1 import argparse import logging -from typing import List import math - +from typing import List import k2 import kaldifeat diff --git a/icefall/byte_utils.py b/icefall/byte_utils.py index 79c1c7545..5f5cc710b 100644 --- a/icefall/byte_utils.py +++ b/icefall/byte_utils.py @@ -8,7 +8,6 @@ import re import unicodedata - WHITESPACE_NORMALIZER = re.compile(r"\s+") SPACE = chr(32) SPACE_ESCAPE = chr(9601) diff --git a/icefall/ctc/prepare_lang.py b/icefall/ctc/prepare_lang.py index 4801b1beb..0e99e70d8 100644 --- a/icefall/ctc/prepare_lang.py +++ b/icefall/ctc/prepare_lang.py @@ -8,12 +8,12 @@ The lang_dir should contain the following files: """ import math +import re from collections import defaultdict from pathlib import Path from typing import List, Tuple import kaldifst -import re class Lexicon: diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 65b6f67b0..a3c480c9c 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -18,7 +18,7 @@ import random from dataclasses import dataclass -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple import torch from torch import Tensor, nn diff --git a/icefall/profiler.py b/icefall/profiler.py index 49e138579..762105c48 100644 --- a/icefall/profiler.py +++ b/icefall/profiler.py @@ -5,14 +5,15 @@ # This is modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py +from collections import OrderedDict +from functools import partial +from typing import List, Optional + import k2 +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from functools import partial -from typing import List, Optional -from collections import OrderedDict -import numpy as np Tensor = torch.Tensor diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py index dfede708b..1070d443a 100755 --- a/icefall/rnn_lm/export-onnx.py +++ b/icefall/rnn_lm/export-onnx.py @@ -5,16 +5,16 @@ import argparse import logging from pathlib import Path +from typing import Dict import onnx import torch from model import RnnLmModel from onnxruntime.quantization import QuantType, quantize_dynamic +from train import get_params from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, str2bool -from typing import Dict -from train import get_params def add_meta_data(filename: str, meta_data: Dict[str, str]): diff --git a/icefall/utils.py b/icefall/utils.py index 7d722b1bc..31f9801d9 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -28,8 +28,6 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from pathlib import Path -from pypinyin import pinyin, lazy_pinyin -from pypinyin.contrib.tone_convert import to_initials, to_finals_tone, to_finals from shutil import copyfile from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union @@ -40,6 +38,8 @@ import sentencepiece as spm import torch import torch.distributed as dist import torch.nn as nn +from pypinyin import lazy_pinyin, pinyin +from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import average_checkpoints