diff --git a/egs/audioset/AT/zipformer/export.py b/egs/audioset/AT/zipformer/export.py index 83034df95..1c0e702d6 100755 --- a/egs/audioset/AT/zipformer/export.py +++ b/egs/audioset/AT/zipformer/export.py @@ -82,9 +82,8 @@ Check ./pretrained.py for its usage. import argparse import logging from pathlib import Path -from typing import List, Tuple +from typing import Tuple -import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from torch import Tensor, nn @@ -96,7 +95,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, num_tokens, str2bool +from icefall.utils import make_pad_mask, str2bool def get_parser(): @@ -302,7 +301,6 @@ def main(): # torch scriptabe. model.__class__.forward = torch.jit.ignore(model.__class__.forward) - model.encoder = EncoderModel(model.encoder, model.encoder_embed) filename = "jit_script.pt" diff --git a/egs/audioset/AT/zipformer/pretrained.py b/egs/audioset/AT/zipformer/pretrained.py index e3961736c..a162a8bb6 100755 --- a/egs/audioset/AT/zipformer/pretrained.py +++ b/egs/audioset/AT/zipformer/pretrained.py @@ -48,16 +48,12 @@ import logging import math from typing import List -import k2 import kaldifeat import torch import torchaudio -from export import num_tokens from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_model, get_params -from icefall.utils import make_pad_mask - def get_parser(): parser = argparse.ArgumentParser( @@ -189,11 +185,12 @@ def main(): encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) logits = model.forward_audio_tagging(encoder_out, encoder_out_lens) - results = [] for i, logit in enumerate(logits): topk_prob, topk_index = logit.sigmoid().topk(5) topk_labels = [label_dict[index.item()] for index in topk_index] - print(f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}") + print( + f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}" + ) logging.info("Decoding Done")