This commit is contained in:
marcoyang 2024-03-26 15:49:57 +08:00
parent 64dbcd07c5
commit 8b234b371a

View File

@ -40,7 +40,6 @@ from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
@ -659,15 +658,12 @@ def compute_loss(
warm_step = params.warm_step
with torch.set_grad_enabled(is_training):
audio_tagging_loss = model(
loss = model(
x=feature,
x_lens=feature_lens,
target=labels,
)
loss = 0.0
loss += audio_tagging_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
@ -677,14 +673,13 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item()
return loss, info
def str2multihot(events: List[str], n_classes=527, id_mapping=None):
# Convert strings separated by semi-colon to multi-hot class labels
# input: ["1;2", "2;3"]
# input: ["0;1", "1;2"]
# output: torch.tensor([[1,1,0], [0,1,1]])
labels = [list(map(int, event.split(";"))) for event in events]
batch_size = len(labels)