mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
update comments in evaluate.py
This commit is contained in:
parent
6a7ac689cf
commit
5a4b712c99
@ -17,13 +17,11 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0"
|
||||||
|
|
||||||
|
|
||||||
./zipformer/evaluate.py \
|
./zipformer/evaluate.py \
|
||||||
--num-epochs 50 \
|
--epoch 50 \
|
||||||
--start-epoch 10 \
|
--avg 10 \
|
||||||
--use-fp16 1 \
|
|
||||||
--exp-dir zipformer/exp \
|
--exp-dir zipformer/exp \
|
||||||
--max-duration 1000
|
--max-duration 1000
|
||||||
|
|
||||||
@ -47,7 +45,11 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from at_datamodule import AudioSetATDatamodule
|
from at_datamodule import AudioSetATDatamodule
|
||||||
from lhotse import load_manifest
|
from lhotse import load_manifest
|
||||||
from sklearn.metrics import average_precision_score
|
|
||||||
|
try:
|
||||||
|
from sklearn.metrics import average_precision_score
|
||||||
|
except Exception as ex:
|
||||||
|
raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn")
|
||||||
from train import add_model_arguments, get_model, get_params, str2multihot
|
from train import add_model_arguments, get_model, get_params, str2multihot
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -130,7 +132,7 @@ def inference_one_batch(
|
|||||||
):
|
):
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
assert feature.ndim == 3
|
assert feature.ndim == 3, feature.shape
|
||||||
|
|
||||||
feature = feature.to(device)
|
feature = feature.to(device)
|
||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
@ -138,7 +140,7 @@ def inference_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
audio_event = supervisions["audio_event"]
|
audio_event = supervisions["audio_event"]
|
||||||
|
|
||||||
label, orig_labels = str2multihot(audio_event)
|
label, _ = str2multihot(audio_event)
|
||||||
label = label.detach().cpu()
|
label = label.detach().cpu()
|
||||||
|
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user