From 5a4b712c99937d6d2e054d25c5e633e8ba6f3ff3 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Fri, 29 Mar 2024 17:12:53 +0800 Subject: [PATCH] update comments in evaluate.py --- egs/audioset/AT/zipformer/evaluate.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/egs/audioset/AT/zipformer/evaluate.py b/egs/audioset/AT/zipformer/evaluate.py index 3fb1c9c02..487c0f901 100644 --- a/egs/audioset/AT/zipformer/evaluate.py +++ b/egs/audioset/AT/zipformer/evaluate.py @@ -17,13 +17,11 @@ """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2,3" - +export CUDA_VISIBLE_DEVICES="0" ./zipformer/evaluate.py \ - --num-epochs 50 \ - --start-epoch 10 \ - --use-fp16 1 \ + --epoch 50 \ + --avg 10 \ --exp-dir zipformer/exp \ --max-duration 1000 @@ -47,7 +45,11 @@ import torch.nn as nn import torch.nn.functional as F from at_datamodule import AudioSetATDatamodule from lhotse import load_manifest -from sklearn.metrics import average_precision_score + +try: + from sklearn.metrics import average_precision_score +except Exception as ex: + raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn") from train import add_model_arguments, get_model, get_params, str2multihot from icefall.checkpoint import ( @@ -130,7 +132,7 @@ def inference_one_batch( ): device = next(model.parameters()).device feature = batch["inputs"] - assert feature.ndim == 3 + assert feature.ndim == 3, feature.shape feature = feature.to(device) # at entry, feature is (N, T, C) @@ -138,7 +140,7 @@ def inference_one_batch( supervisions = batch["supervisions"] audio_event = supervisions["audio_event"] - label, orig_labels = str2multihot(audio_event) + label, _ = str2multihot(audio_event) label = label.detach().cpu() feature_lens = supervisions["num_frames"].to(device)