fix style

This commit is contained in:
marcoyang 2024-03-26 10:44:39 +08:00
parent 18479fceb3
commit 7a8c9b7f53
2 changed files with 5 additions and 10 deletions

View File

@ -82,9 +82,8 @@ Check ./pretrained.py for its usage.
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
from typing import Tuple
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
@ -96,7 +95,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
from icefall.utils import make_pad_mask, str2bool
def get_parser():
@ -302,7 +301,6 @@ def main():
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"

View File

@ -48,16 +48,12 @@ import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from icefall.utils import make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
@ -189,11 +185,12 @@ def main():
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
results = []
for i, logit in enumerate(logits):
topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
print(f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}")
print(
f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}"
)
logging.info("Decoding Done")