icefall/egs/wenetspeech/ASR/local/fix_manifest.py
2024-04-24 18:57:34 +08:00

114 lines
4.3 KiB
Python

#!/usr/bin/env python3
# Copyright 2024 author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os

from lhotse import CutSet, load_manifest_lazy
def get_parser():
    """Build the command-line parser for this script.

    The parser uses ``argparse.ArgumentDefaultsHelpFormatter`` and registers
    three string options:

    * ``--fixed-transcript-path`` (default ``data/fbank/text.fix``)
    * ``--manifest-dir`` (default ``data/fbank/``)
    * ``--training-subset`` (default ``L``)

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    transcript_help = (
        "\n"
        "See https://github.com/wenet-e2e/WenetSpeech/discussions/54\n"
        "wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix\n"
    )

    # (flag, default value, help text) for every option; all are strings.
    option_specs = (
        ("--fixed-transcript-path", "data/fbank/text.fix", transcript_help),
        ("--manifest-dir", "data/fbank/", "Directory to store the manifest files"),
        ("--training-subset", "L", "The training subset for wenetspeech."),
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli
def load_fixed_text(fixed_text_path):
    """Load corrected transcripts from a text file into a dict.

    Each non-blank line holds a cut id, a single space, and the corrected
    transcript, for example::

        X0000016287_92761015_S00001 我是徐涛
        X0000016287_92761015_S00002 狄更斯的PICK WEEK PAPERS斯

    Args:
        fixed_text_path: Path of the transcript file to read.

    Returns:
        A dict mapping cut id to transcript text. Blank lines are skipped.
        A line that contains only a cut id maps that id to the empty
        string. If a cut id appears several times, the last line wins.
    """
    fixed_text_dict = {}
    # The transcripts contain non-ASCII text (see the example above), so
    # decode explicitly as UTF-8 instead of relying on the locale default.
    with open(fixed_text_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no entry.
                continue
            # partition never raises, unlike unpacking split(" ", 1), when
            # the line has an id but no transcript.
            cut_id, _, text = line.partition(" ")
            fixed_text_dict[cut_id] = text
    return fixed_text_dict
def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
    """Copy a manifest to a new file, replacing transcripts found in a dict.

    Every cut of ``manifest`` is written to ``fixed_manifest_path``. Before
    the lookup, a trailing ``_sp0.9`` or ``_sp1.1`` is removed from the cut
    id (presumably speed-perturbation suffixes -- confirm against the data
    preparation recipe). When the resulting id is a key of
    ``fixed_text_dict`` and the stored text differs from the current one,
    the text of the cut's single supervision is overwritten in place.

    Args:
        manifest: Iterable of cuts exposing ``id`` and ``supervisions``.
        fixed_text_dict: Mapping from cut id to corrected transcript.
        fixed_manifest_path: Destination passed to ``CutSet.open_writer``.

    Raises:
        AssertionError: If a cut with a corrected transcript does not have
            exactly one supervision.
    """
    perturbed_suffixes = ("_sp0.9", "_sp1.1")
    num_fixed = 0
    with CutSet.open_writer(fixed_manifest_path) as writer:
        for index, cut in enumerate(manifest):
            if index % 10000 == 0:
                logging.info(f"Processing cut {index}, fixed {num_fixed}")

            original_id = cut.id
            # Both suffixes are six characters long, so one slice handles either.
            if original_id.endswith(perturbed_suffixes):
                lookup_id = original_id[:-6]
            else:
                lookup_id = original_id

            if lookup_id in fixed_text_dict:
                corrected = fixed_text_dict[lookup_id]
                num_supervisions = len(cut.supervisions)
                assert (
                    num_supervisions == 1
                ), f"cut {lookup_id} has {num_supervisions} supervisions"
                supervision = cut.supervisions[0]
                if supervision.text != corrected:
                    logging.info(
                        f"Fixed text for cut {original_id} from {supervision.text} to {corrected}"
                    )
                    supervision.text = corrected
                    num_fixed += 1

            # Every cut is written, whether or not its text was changed.
            writer.write(cut)
def main():
    """Fix the transcripts of the DEV and training manifests.

    Parses the command line, loads the corrected transcripts from
    ``--fixed-transcript-path``, then rewrites
    ``cuts_DEV.jsonl.gz`` and ``cuts_<training-subset>.jsonl.gz`` found in
    ``--manifest-dir`` into ``cuts_DEV_fixed.jsonl.gz`` and
    ``cuts_<training-subset>_fixed.jsonl.gz`` in the same directory.
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    # Honour --fixed-transcript-path; it used to be parsed but ignored in
    # favour of a hard-coded "<manifest-dir>text.fix".
    fixed_text_dict = load_fixed_text(args.fixed_transcript_path)
    logging.info(f"Loaded {len(fixed_text_dict)} fixed texts")

    # os.path.join keeps the paths valid whether or not --manifest-dir ends
    # with a path separator.
    dev_manifest_path = os.path.join(args.manifest_dir, "cuts_DEV.jsonl.gz")
    fixed_dev_manifest_path = os.path.join(
        args.manifest_dir, "cuts_DEV_fixed.jsonl.gz"
    )
    logging.info(f"Loading dev manifest from {dev_manifest_path}")
    cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
    fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
    logging.info(f"Fixed dev manifest saved to {fixed_dev_manifest_path}")

    manifest_path = os.path.join(
        args.manifest_dir, f"cuts_{args.training_subset}.jsonl.gz"
    )
    fixed_manifest_path = os.path.join(
        args.manifest_dir, f"cuts_{args.training_subset}_fixed.jsonl.gz"
    )
    logging.info(f"Loading manifest from {manifest_path}")
    cuts_manifest = load_manifest_lazy(manifest_path)
    fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
    logging.info(f"Fixed training manifest saved to {fixed_manifest_path}")
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()