update comments in evaluate.py

This commit is contained in:
marcoyang 2024-03-29 17:12:53 +08:00
parent 6a7ac689cf
commit 5a4b712c99

View File

@ -17,13 +17,11 @@
""" """
Usage: Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0"
./zipformer/evaluate.py \ ./zipformer/evaluate.py \
--num-epochs 50 \ --epoch 50 \
--start-epoch 10 \ --avg 10 \
--use-fp16 1 \
--exp-dir zipformer/exp \ --exp-dir zipformer/exp \
--max-duration 1000 --max-duration 1000
@ -47,7 +45,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from at_datamodule import AudioSetATDatamodule from at_datamodule import AudioSetATDatamodule
from lhotse import load_manifest from lhotse import load_manifest
from sklearn.metrics import average_precision_score
try:
from sklearn.metrics import average_precision_score
except Exception as ex:
raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn")
from train import add_model_arguments, get_model, get_params, str2multihot from train import add_model_arguments, get_model, get_params, str2multihot
from icefall.checkpoint import ( from icefall.checkpoint import (
@ -130,7 +132,7 @@ def inference_one_batch(
): ):
device = next(model.parameters()).device device = next(model.parameters()).device
feature = batch["inputs"] feature = batch["inputs"]
assert feature.ndim == 3 assert feature.ndim == 3, feature.shape
feature = feature.to(device) feature = feature.to(device)
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
@ -138,7 +140,7 @@ def inference_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
audio_event = supervisions["audio_event"] audio_event = supervisions["audio_event"]
label, orig_labels = str2multihot(audio_event) label, _ = str2multihot(audio_event)
label = label.detach().cpu() label = label.detach().cpu()
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)