mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
fix style
This commit is contained in:
parent
686d2d9787
commit
f3e8e42265
@ -17,7 +17,6 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
import pickle
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
@ -160,8 +160,6 @@ def decode_dataset(
|
||||
model: nn.Module,
|
||||
) -> Dict:
|
||||
num_cuts = 0
|
||||
embedding_dict = {}
|
||||
teacher_embedding_dict = {}
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
|
@ -62,7 +62,7 @@ use the exported ONNX models.
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict
|
||||
|
||||
import k2
|
||||
import onnx
|
||||
@ -189,9 +189,9 @@ class OnnxAudioTagger(nn.Module):
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||
Return a tensor containing:
|
||||
- logits, A 2-D tensor of shape (N, num_classes)
|
||||
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
@ -200,13 +200,12 @@ class OnnxAudioTagger(nn.Module):
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C)
|
||||
|
||||
logits = self.classifier(encoder_out) # (N, T, num_classes)
|
||||
padding_mask = make_pad_mask(encoder_out_lens)
|
||||
logits[padding_mask] = 0
|
||||
logits = logits.sum(dim=1) # mask the padding frames
|
||||
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
|
||||
logits
|
||||
) # normalize the logits
|
||||
print(logits.shape)
|
||||
# Note that this is slightly different from model.py for better
|
||||
# support of onnx
|
||||
N = logits.shape[0]
|
||||
for i in range(N):
|
||||
logits[i, encoder_out_lens[i] :] = 0
|
||||
logits = logits.sum(dim=1) / encoder_out_lens.unsqueeze(-1)
|
||||
return logits
|
||||
|
||||
|
||||
@ -237,7 +236,7 @@ def export_audio_tagging_model_onnx(
|
||||
x = torch.zeros(1, 200, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([200], dtype=torch.int64)
|
||||
|
||||
model = torch.jit.script(model)
|
||||
model = torch.jit.trace(model, (x, x_lens))
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
|
@ -135,7 +135,6 @@ class OnnxModel:
|
||||
meta = self.model.get_modelmeta().custom_metadata_map
|
||||
print(meta)
|
||||
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -162,6 +161,7 @@ class OnnxModel:
|
||||
)
|
||||
return torch.from_numpy(out[0])
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user