diff --git a/egs/audioset/AT/local/compute_weight.py b/egs/audioset/AT/local/compute_weight.py index 1da6eb23c..a0deddc0c 100644 --- a/egs/audioset/AT/local/compute_weight.py +++ b/egs/audioset/AT/local/compute_weight.py @@ -25,25 +25,24 @@ import argparse import lhotse from lhotse import load_manifest + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( - "--input-manifest", - type=str, - default="data/fbank/cuts_audioset_full.jsonl.gz" + "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz" ) parser.add_argument( "--output", type=str, required=True, - ) return parser + def main(): # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py parser = get_parser() @@ -53,7 +52,7 @@ def main(): print(f"A total of {len(cuts)} cuts.") - label_count = [0] * 527 # a total of 527 classes + label_count = [0] * 527 # a total of 527 classes for c in cuts: audio_event = c.supervisions[0].audio_event labels = list(map(int, audio_event.split(";"))) @@ -68,6 +67,7 @@ def main(): for label in labels: weight += 1000 / (label_count[label] + 0.01) f.write(f"{c.id} {weight}\n") - + + if __name__ == "__main__": main()