mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
fix style
This commit is contained in:
parent
18479fceb3
commit
7a8c9b7f53
@ -82,9 +82,8 @@ Check ./pretrained.py for its usage.
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from torch import Tensor, nn
|
||||
@ -96,7 +95,7 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||
from icefall.utils import make_pad_mask, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -302,7 +301,6 @@ def main():
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
|
||||
|
||||
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
|
||||
filename = "jit_script.pt"
|
||||
|
||||
|
@ -48,16 +48,12 @@ import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from export import num_tokens
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -189,11 +185,12 @@ def main():
|
||||
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
|
||||
logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
|
||||
|
||||
results = []
|
||||
for i, logit in enumerate(logits):
|
||||
topk_prob, topk_index = logit.sigmoid().topk(5)
|
||||
topk_labels = [label_dict[index.item()] for index in topk_index]
|
||||
print(f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}")
|
||||
print(
|
||||
f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}"
|
||||
)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user