From 8b234b371a5f4ecb3378923eecf1cf9a44ab59df Mon Sep 17 00:00:00 2001 From: marcoyang Date: Tue, 26 Mar 2024 15:49:57 +0800 Subject: [PATCH] fix doc --- egs/audioset/AT/zipformer/train.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index a6e6490ad..917c9d9a3 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -40,7 +40,6 @@ from pathlib import Path from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple, Union -import k2 import optim import sentencepiece as spm import torch @@ -659,15 +658,12 @@ def compute_loss( warm_step = params.warm_step with torch.set_grad_enabled(is_training): - audio_tagging_loss = model( + loss = model( x=feature, x_lens=feature_lens, target=labels, ) - loss = 0.0 - loss += audio_tagging_loss - assert loss.requires_grad == is_training info = MetricsTracker() @@ -677,14 +673,13 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() - info["audio_tagging_loss"] = audio_tagging_loss.detach().cpu().item() return loss, info def str2multihot(events: List[str], n_classes=527, id_mapping=None): # Convert strings separated by semi-colon to multi-hot class labels - # input: ["1;2", "2;3"] + # input: ["0;1", "1;2"] # output: torch.tensor([[1,1,0], [0,1,1]]) labels = [list(map(int, event.split(";"))) for event in events] batch_size = len(labels)