This commit is contained in:
marcoyang 2024-03-26 15:49:57 +08:00
parent 64dbcd07c5
commit 8b234b371a

View File

@ -40,7 +40,6 @@ from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import k2
import optim import optim
import sentencepiece as spm import sentencepiece as spm
import torch import torch
@ -659,15 +658,12 @@ def compute_loss(
warm_step = params.warm_step warm_step = params.warm_step
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
audio_tagging_loss = model( loss = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
target=labels, target=labels,
) )
loss = 0.0
loss += audio_tagging_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
@ -677,14 +673,13 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item()
return loss, info return loss, info
def str2multihot(events: List[str], n_classes=527, id_mapping=None): def str2multihot(events: List[str], n_classes=527, id_mapping=None):
# Convert strings separated by semi-colon to multi-hot class labels # Convert strings separated by semi-colon to multi-hot class labels
# input: ["1;2", "2;3"] # input: ["0;1", "1;2"]
# output: torch.tensor([[1,1,0], [0,1,1]]) # output: torch.tensor([[1,1,0], [0,1,1]])
labels = [list(map(int, event.split(";"))) for event in events] labels = [list(map(int, event.split(";"))) for event in events]
batch_size = len(labels) batch_size = len(labels)