diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 917c9d9a3..0e234c59f 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -648,7 +648,9 @@ def compute_loss( feature = feature.to(device) supervisions = batch["supervisions"] - events = supervisions["audio_event"] # the label indices are in CED format + events = supervisions[ + "audio_event" + ] # the label indices are in CED format (https://github.com/RicherMans/CED) labels, _ = str2multihot(events, n_classes=params.num_events) labels = labels.to(device)