#!/usr/bin/env python3
#
# Copyright    2023  Xiaomi Corporation     (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage: ./pruned_transducer_stateless/my_profile.py
"""

import argparse
import logging

import sentencepiece as spm
import torch

from icefall.profiler import get_model_profile
from train import get_encoder_model, add_model_arguments, get_params


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this profiling script.

    Returns:
      An ``ArgumentParser`` that accepts ``--bpe-model`` plus whatever
      options ``add_model_arguments`` (imported from ``train``) registers.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    add_model_arguments(parser)

    return parser


@torch.no_grad()
def main() -> None:
    """Build the encoder model and log its parameter count and FLOPs.

    Command-line arguments are merged into the defaults returned by
    ``get_params()``. The encoder is created with ``get_encoder_model``,
    put in eval mode, moved to CUDA device 0 when available (CPU
    otherwise), and profiled with ``get_model_profile`` on a dummy
    all-ones batch. Results are reported through ``logging``.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    # NOTE(review): the "<blk>" token had been stripped from both this
    # comment and the lookup below, leaving piece_to_id("") -- confirm the
    # symbol name against local/train_bpe_model.py.
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")

    # We only profile the encoder part
    model = get_encoder_model(params)
    model.eval()
    model.to(device)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # for 30-second input
    # (presumably 3000 frames of 80-dim features -- TODO confirm frame shift)
    B, T, D = 1, 3000, 80
    feature = torch.ones(B, T, D, dtype=torch.float32, device=device)
    feature_lens = torch.full((B,), T, dtype=torch.int64, device=device)

    # Use a distinct name for the profiler's result so the configuration
    # object ``params`` above is not silently overwritten.
    flops, profiled_params = get_model_profile(
        model=model, args=(feature, feature_lens)
    )
    logging.info(
        f"For the encoder part, params: {profiled_params}, flops: {flops}"
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    main()