mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
update comments in evaluate.py
This commit is contained in:
parent
6a7ac689cf
commit
5a4b712c99
@ -17,13 +17,11 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
./zipformer/evaluate.py \
|
||||
--num-epochs 50 \
|
||||
--start-epoch 10 \
|
||||
--use-fp16 1 \
|
||||
--epoch 50 \
|
||||
--avg 10 \
|
||||
--exp-dir zipformer/exp \
|
||||
--max-duration 1000
|
||||
|
||||
@ -47,7 +45,11 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from at_datamodule import AudioSetATDatamodule
|
||||
from lhotse import load_manifest
|
||||
from sklearn.metrics import average_precision_score
|
||||
|
||||
try:
|
||||
from sklearn.metrics import average_precision_score
|
||||
except Exception as ex:
|
||||
raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn")
|
||||
from train import add_model_arguments, get_model, get_params, str2multihot
|
||||
|
||||
from icefall.checkpoint import (
|
||||
@ -130,7 +132,7 @@ def inference_one_batch(
|
||||
):
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
assert feature.ndim == 3, feature.shape
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
@ -138,7 +140,7 @@ def inference_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
audio_event = supervisions["audio_event"]
|
||||
|
||||
label, orig_labels = str2multihot(audio_event)
|
||||
label, _ = str2multihot(audio_event)
|
||||
label = label.detach().cpu()
|
||||
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user