Mirror of https://github.com/k2-fsa/icefall
Commit 7a8c9b7f53 ("fix style"), parent 18479fceb3
@@ -82,9 +82,8 @@ Check ./pretrained.py for its usage.
 import argparse
 import logging
 from pathlib import Path
-from typing import List, Tuple
+from typing import Tuple
 
-import k2
 import torch
 from scaling_converter import convert_scaled_to_non_scaled
 from torch import Tensor, nn
@@ -96,7 +95,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import make_pad_mask, num_tokens, str2bool
+from icefall.utils import make_pad_mask, str2bool
 
 
 def get_parser():
@@ -302,7 +301,6 @@ def main():
         # torch scriptabe.
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
 
-
         model.encoder = EncoderModel(model.encoder, model.encoder_embed)
         filename = "jit_script.pt"
 
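For context, the last hunk above is the TorchScript export path: the model's training-time forward(), whose arguments TorchScript cannot handle, is excluded from compilation with torch.jit.ignore, the encoder is wrapped in a scriptable EncoderModel, and the result is saved as jit_script.pt (which implies a torch.jit.script call just below the shown lines). The following is a minimal, self-contained sketch of that pattern; Model and EncoderModel here are illustrative stand-ins, not the recipe's actual classes.

import torch
from torch import Tensor, nn


class EncoderModel(nn.Module):
    # Illustrative stand-in for the wrapper in the diff: it bundles the
    # encoder frontend (encoder_embed) and the encoder into one scriptable module.
    def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed

    def forward(self, features: Tensor) -> Tensor:
        return self.encoder(self.encoder_embed(features))


class Model(nn.Module):
    # Toy model whose forward() stands in for training code that TorchScript
    # cannot compile (e.g. it takes a ragged tensor, as the diff's comment notes).
    def __init__(self) -> None:
        super().__init__()
        self.encoder_embed = nn.Linear(80, 32)
        self.encoder = nn.Linear(32, 32)

    def forward(self, features, supervisions):
        raise NotImplementedError


model = Model()
# Same trick as in the hunk: leave the non-scriptable forward() out of
# TorchScript compilation instead of rewriting it.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Swap in the scriptable encoder wrapper, then script and save.
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
model = torch.jit.script(model)
model.save("jit_script.pt")

The style fix in that hunk only drops one of two consecutive blank lines; the export logic is unchanged. The hunks below belong to a second file in the same recipe, whose line numbering restarts at the next @@ header.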
@@ -48,16 +48,12 @@ import logging
 import math
 from typing import List
 
-import k2
 import kaldifeat
 import torch
 import torchaudio
-from export import num_tokens
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
 
-from icefall.utils import make_pad_mask
-
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -189,11 +185,12 @@ def main():
     encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
     logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
 
-    results = []
     for i, logit in enumerate(logits):
         topk_prob, topk_index = logit.sigmoid().topk(5)
         topk_labels = [label_dict[index.item()] for index in topk_index]
-        print(f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}")
+        print(
+            f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}"
+        )
 
     logging.info("Decoding Done")
 
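The second hunk of this file is the audio-tagging decode step: the encoder output is mapped to per-class logits, each utterance's logits go through a sigmoid, and the five highest-scoring labels are printed; the style fix drops the unused results list and wraps the long print call. Below is a small, self-contained sketch of that post-processing, assuming random logits and a made-up label_dict in place of the model output and the real class-label map (527 classes is only an example size).

import torch

# Hypothetical label map; the recipe loads the real one from a class-label file.
label_dict = {i: f"class_{i}" for i in range(527)}

# Stand-in for model.forward_audio_tagging(encoder_out, encoder_out_lens):
# one logit vector per utterance.
logits = torch.randn(2, 527)

for i, logit in enumerate(logits):
    # Multi-label tagging: an independent sigmoid per class, then keep the top 5.
    topk_prob, topk_index = logit.sigmoid().topk(5)
    topk_labels = [label_dict[index.item()] for index in topk_index]
    print(
        f"Top 5 predicted labels of the {i} th audio are {topk_labels} "
        f"with probability of {topk_prob.tolist()}"
    )

Because the task is multi-label, a per-class sigmoid is used rather than a softmax over classes, so the reported probabilities are independent scores and need not sum to one.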