fix style

This commit is contained in:
marcoyang 2024-03-26 10:44:39 +08:00
parent 18479fceb3
commit 7a8c9b7f53
2 changed files with 5 additions and 10 deletions

View File

@ -82,9 +82,8 @@ Check ./pretrained.py for its usage.
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
from typing import List, Tuple from typing import Tuple
import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn from torch import Tensor, nn
@ -96,7 +95,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import make_pad_mask, num_tokens, str2bool from icefall.utils import make_pad_mask, str2bool
def get_parser(): def get_parser():
@ -302,7 +301,6 @@ def main():
# torch scriptabe. # torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward) model.__class__.forward = torch.jit.ignore(model.__class__.forward)
model.encoder = EncoderModel(model.encoder, model.encoder_embed) model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt" filename = "jit_script.pt"

View File

@ -48,16 +48,12 @@ import logging
import math import math
from typing import List from typing import List
import k2
import kaldifeat import kaldifeat
import torch import torch
import torchaudio import torchaudio
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
from icefall.utils import make_pad_mask
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -189,11 +185,12 @@ def main():
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
logits = model.forward_audio_tagging(encoder_out, encoder_out_lens) logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
results = []
for i, logit in enumerate(logits): for i, logit in enumerate(logits):
topk_prob, topk_index = logit.sigmoid().topk(5) topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index] topk_labels = [label_dict[index.item()] for index in topk_index]
print(f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}") print(
f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}"
)
logging.info("Decoding Done") logging.info("Decoding Done")