Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
fix doc
commit 8b234b371a (parent 64dbcd07c5)
```diff
@@ -40,7 +40,6 @@ from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import k2
 import optim
 import sentencepiece as spm
 import torch
```
```diff
@@ -659,15 +658,12 @@ def compute_loss(
     warm_step = params.warm_step
 
     with torch.set_grad_enabled(is_training):
-        audio_tagging_loss = model(
+        loss = model(
             x=feature,
             x_lens=feature_lens,
             target=labels,
         )
 
-        loss = 0.0
-        loss += audio_tagging_loss
-
         assert loss.requires_grad == is_training
 
     info = MetricsTracker()
```
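For context, the simplification above works because the model's forward pass returns the (sum-reduced) loss directly, so `compute_loss` no longer needs a separate accumulator. Below is a minimal sketch of that pattern; `ToyAudioTagger`, its dimensions, and the mean-pool classifier are illustrative assumptions, not the recipe's actual model.

```python
import torch
import torch.nn as nn


class ToyAudioTagger(nn.Module):
    # Stand-in for the recipe's audio-tagging model: forward() returns the
    # loss itself, which is what lets compute_loss assign it directly.
    def __init__(self, feat_dim: int = 80, n_classes: int = 527):
        super().__init__()
        self.proj = nn.Linear(feat_dim, n_classes)
        # reduction="sum", matching the note in compute_loss.
        self.criterion = nn.BCEWithLogitsLoss(reduction="sum")

    def forward(self, x, x_lens, target):
        # Mean-pool over time and classify; the real encoder is more complex.
        logits = self.proj(x.mean(dim=1))
        return self.criterion(logits, target)


model = ToyAudioTagger()
feature = torch.randn(2, 100, 80)      # (batch, time, feat_dim)
feature_lens = torch.full((2,), 100)   # unused by the toy model
labels = torch.zeros(2, 527)
labels[0, :2] = 1.0                    # multi-hot targets

is_training = True
with torch.set_grad_enabled(is_training):
    loss = model(x=feature, x_lens=feature_lens, target=labels)
assert loss.requires_grad == is_training
```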
```diff
@@ -677,14 +673,13 @@ def compute_loss(
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
-    info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item()
 
     return loss, info
 
 
 def str2multihot(events: List[str], n_classes=527, id_mapping=None):
     # Convert strings separated by semi-colon to multi-hot class labels
-    # input: ["1;2", "2;3"]
+    # input: ["0;1", "1;2"]
     # output: torch.tensor([[1,1,0], [0,1,1]])
     labels = [list(map(int, event.split(";"))) for event in events]
     batch_size = len(labels)
```
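The doc fix above makes the example self-consistent: with zero-indexed classes and three classes, inputs `["0;1", "1;2"]` do produce `[[1,1,0], [0,1,1]]`. The diff only shows the first two lines of the function body, so the completion below is a hedged sketch of how the documented behavior could be implemented; everything after `batch_size` (including how `id_mapping` is applied) is an assumption.

```python
from typing import Dict, List, Optional

import torch


def str2multihot(
    events: List[str],
    n_classes: int = 527,
    id_mapping: Optional[Dict[int, int]] = None,
) -> torch.Tensor:
    # Convert strings separated by semi-colon to multi-hot class labels
    # input: ["0;1", "1;2"]
    # output: torch.tensor([[1,1,0], [0,1,1]])
    labels = [list(map(int, event.split(";"))) for event in events]
    batch_size = len(labels)
    multi_hot = torch.zeros(batch_size, n_classes)
    for i, ids in enumerate(labels):
        for class_id in ids:
            if id_mapping is not None:
                class_id = id_mapping[class_id]  # optional id remapping (assumed)
            multi_hot[i, class_id] = 1.0
    return multi_hot


print(str2multihot(["0;1", "1;2"], n_classes=3))
# tensor([[1., 1., 0.],
#         [0., 1., 1.]])
```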