fix format

This commit is contained in:
marcoyang 2024-08-22 14:22:56 +08:00
parent 01a85c5335
commit 298230177a

View File

@ -25,25 +25,24 @@ import argparse
import lhotse import lhotse
from lhotse import load_manifest from lhotse import load_manifest
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
parser.add_argument( parser.add_argument(
"--input-manifest", "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz"
type=str,
default="data/fbank/cuts_audioset_full.jsonl.gz"
) )
parser.add_argument( parser.add_argument(
"--output", "--output",
type=str, type=str,
required=True, required=True,
) )
return parser return parser
def main(): def main():
# Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py
parser = get_parser() parser = get_parser()
@ -53,7 +52,7 @@ def main():
print(f"A total of {len(cuts)} cuts.") print(f"A total of {len(cuts)} cuts.")
label_count = [0] * 527 # a total of 527 classes label_count = [0] * 527 # a total of 527 classes
for c in cuts: for c in cuts:
audio_event = c.supervisions[0].audio_event audio_event = c.supervisions[0].audio_event
labels = list(map(int, audio_event.split(";"))) labels = list(map(int, audio_event.split(";")))
@ -68,6 +67,7 @@ def main():
for label in labels: for label in labels:
weight += 1000 / (label_count[label] + 0.01) weight += 1000 / (label_count[label] + 0.01)
f.write(f"{c.id} {weight}\n") f.write(f"{c.id} {weight}\n")
if __name__ == "__main__": if __name__ == "__main__":
main() main()