mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
fix doc
This commit is contained in:
parent
64dbcd07c5
commit
8b234b371a
@ -40,7 +40,6 @@ from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
@ -659,15 +658,12 @@ def compute_loss(
|
||||
warm_step = params.warm_step
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
audio_tagging_loss = model(
|
||||
loss = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
target=labels,
|
||||
)
|
||||
|
||||
loss = 0.0
|
||||
loss += audio_tagging_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
@ -677,14 +673,13 @@ def compute_loss(
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
def str2multihot(events: List[str], n_classes=527, id_mapping=None):
|
||||
# Convert strings separated by semi-colon to multi-hot class labels
|
||||
# input: ["1;2", "2;3"]
|
||||
# input: ["0;1", "1;2"]
|
||||
# output: torch.tensor([[1,1,0], [0,1,1]])
|
||||
labels = [list(map(int, event.split(";"))) for event in events]
|
||||
batch_size = len(labels)
|
||||
|
Loading…
x
Reference in New Issue
Block a user