mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
* add flops profiler, support for Zipformer encoder and Conformer encoder * support for reworked conformer and old zipformer * skip black check
95 lines
2.6 KiB
Python
Executable File
95 lines
2.6 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
#
|
|
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Usage: ./pruned_transducer_stateless/profile.py
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import sentencepiece as spm
|
|
import torch
|
|
|
|
from icefall.profiler import get_model_profile
|
|
from train import get_encoder_model, add_model_arguments, get_params
|
|
|
|
|
|
def get_parser():
    """Create the command-line parser for this profiling script.

    The parser knows ``--bpe-model`` and every option that
    ``add_model_arguments`` (imported from ``train``) registers. Default
    values are shown in ``--help`` output.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    arg_parser.add_argument(
        "--bpe-model",
        default="data/lang_bpe_500/bpe.model",
        type=str,
        help="Path to the BPE model",
    )

    # Model hyper-parameter options come from the training script, so the
    # encoder is configured from the same flags that train.py accepts.
    add_model_arguments(arg_parser)

    return arg_parser
|
|
|
|
|
|
@torch.no_grad()
def main():
    """Build the encoder described by the command line and log its cost.

    Steps:
      1. Parse the command-line options and merge them into the defaults
         returned by ``train.get_params()``.
      2. Load the BPE model to fill in ``blank_id`` and ``vocab_size``.
      3. Build only the encoder with ``get_encoder_model``.
      4. Log its parameter count, then the parameters and FLOPs reported by
         ``icefall.profiler.get_model_profile`` for one dummy utterance.

    Runs under ``torch.no_grad()`` because profiling needs no gradients.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")

    # We only profile the encoder part
    model = get_encoder_model(params)
    model.eval()
    model.to(device)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # Dummy input standing in for a 30-second utterance. 3000 frames assumes
    # a 10 ms frame shift (NOTE(review): confirm against the feature
    # extractor). Use params.feature_dim when the training defaults provide
    # it, so the input width matches what the encoder was built with; fall
    # back to the previously hard-coded 80 otherwise.
    batch_size, num_frames = 1, 3000
    feature_dim = getattr(params, "feature_dim", 80)
    # Allocate directly on the target device instead of creating the tensor
    # on the CPU and copying it over.
    feature = torch.ones(
        batch_size, num_frames, feature_dim, dtype=torch.float32, device=device
    )
    feature_lens = torch.full(
        (batch_size,), num_frames, dtype=torch.int64, device=device
    )

    # Use a distinct name for the profiler's parameter count so it does not
    # shadow the hyper-parameter dict ``params`` used above.
    flops, profiled_params = get_model_profile(
        model=model, args=(feature, feature_lens)
    )
    logging.info(
        f"For the encoder part, params: {profiled_params}, flops: {flops}"
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Log lines carry a timestamp, level and source location; INFO and above
    # are emitted.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    main()
|