mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
fix style
This commit is contained in:
parent
686d2d9787
commit
f3e8e42265
@ -17,7 +17,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
@ -160,8 +160,6 @@ def decode_dataset(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
num_cuts = 0
|
num_cuts = 0
|
||||||
embedding_dict = {}
|
|
||||||
teacher_embedding_dict = {}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
num_batches = len(dl)
|
num_batches = len(dl)
|
||||||
|
@ -62,7 +62,7 @@ use the exported ONNX models.
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Tuple
|
from typing import Dict
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
@ -189,9 +189,9 @@ class OnnxAudioTagger(nn.Module):
|
|||||||
x_lens:
|
x_lens:
|
||||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||||
Returns:
|
Returns:
|
||||||
Return a tuple containing:
|
Return a tensor containing:
|
||||||
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
- logits, A 2-D tensor of shape (N, num_classes)
|
||||||
- encoder_out_lens, A 1-D tensor of shape (N,)
|
|
||||||
"""
|
"""
|
||||||
x, x_lens = self.encoder_embed(x, x_lens)
|
x, x_lens = self.encoder_embed(x, x_lens)
|
||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
@ -200,13 +200,12 @@ class OnnxAudioTagger(nn.Module):
|
|||||||
encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C)
|
encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C)
|
||||||
|
|
||||||
logits = self.classifier(encoder_out) # (N, T, num_classes)
|
logits = self.classifier(encoder_out) # (N, T, num_classes)
|
||||||
padding_mask = make_pad_mask(encoder_out_lens)
|
# Note that this is slightly different from model.py for better
|
||||||
logits[padding_mask] = 0
|
# support of onnx
|
||||||
logits = logits.sum(dim=1) # mask the padding frames
|
N = logits.shape[0]
|
||||||
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
|
for i in range(N):
|
||||||
logits
|
logits[i, encoder_out_lens[i] :] = 0
|
||||||
) # normalize the logits
|
logits = logits.sum(dim=1) / encoder_out_lens.unsqueeze(-1)
|
||||||
print(logits.shape)
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@ -237,7 +236,7 @@ def export_audio_tagging_model_onnx(
|
|||||||
x = torch.zeros(1, 200, 80, dtype=torch.float32)
|
x = torch.zeros(1, 200, 80, dtype=torch.float32)
|
||||||
x_lens = torch.tensor([200], dtype=torch.int64)
|
x_lens = torch.tensor([200], dtype=torch.int64)
|
||||||
|
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.trace(model, (x, x_lens))
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
model,
|
model,
|
||||||
|
@ -135,7 +135,6 @@ class OnnxModel:
|
|||||||
meta = self.model.get_modelmeta().custom_metadata_map
|
meta = self.model.get_modelmeta().custom_metadata_map
|
||||||
print(meta)
|
print(meta)
|
||||||
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -162,6 +161,7 @@ class OnnxModel:
|
|||||||
)
|
)
|
||||||
return torch.from_numpy(out[0])
|
return torch.from_numpy(out[0])
|
||||||
|
|
||||||
|
|
||||||
def read_sound_files(
|
def read_sound_files(
|
||||||
filenames: List[str], expected_sample_rate: float
|
filenames: List[str], expected_sample_rate: float
|
||||||
) -> List[torch.Tensor]:
|
) -> List[torch.Tensor]:
|
||||||
@ -232,7 +232,7 @@ def main():
|
|||||||
|
|
||||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||||
logits = model(features, feature_lengths)
|
logits = model(features, feature_lengths)
|
||||||
|
|
||||||
for filename, logit in zip(args.sound_files, logits):
|
for filename, logit in zip(args.sound_files, logits):
|
||||||
topk_prob, topk_index = logit.sigmoid().topk(5)
|
topk_prob, topk_index = logit.sigmoid().topk(5)
|
||||||
topk_labels = [label_dict[index.item()] for index in topk_index]
|
topk_labels = [label_dict[index.item()] for index in topk_index]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user