From 0982db9cde89dd3a59b5c92cbbe70d245130b9d2 Mon Sep 17 00:00:00 2001 From: marcoyang1998 Date: Wed, 16 Aug 2023 16:44:58 +0800 Subject: [PATCH] add a few args to support context list and rare words --- .../zipformer_prompt_asr/asr_datamodule.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py index 9a1aec1de..635272e17 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py @@ -23,7 +23,7 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional import torch -from dataset import PromptASRDataset +from dataset2 import PromptASRDataset from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse.dataset import ( CutConcatenate, @@ -70,6 +70,15 @@ class LibriHeavyAsrDataModule: def __init__(self, args: argparse.Namespace): self.args = args + + if args.use_context_list: + from dataset2 import PromptASRDataset + assert args.rare_word_file is not None + with open(args.rare_word_file, 'r') as f: + self.rare_word_list = f.read().lower().split() # Use lower-cased for easier style transform + else: + from dataset import PromptASRDataset + self.rare_word_list = None @classmethod def add_arguments(cls, parser: argparse.ArgumentParser): @@ -201,6 +210,12 @@ class LibriHeavyAsrDataModule: help="Use the context list of libri heavy", ) + group.add_argument( + "--min-count", + type=int, + default=7, + ) + group.add_argument( "--with-decoding", type=str2bool, @@ -212,6 +227,11 @@ class LibriHeavyAsrDataModule: "--random-left-padding", type=str2bool, ) + + group.add_argument( + "--rare-word-file", + type=str, + ) def train_dataloaders( self, @@ -290,6 +310,7 @@ class LibriHeavyAsrDataModule: input_transforms=input_transforms, return_cuts=self.args.return_cuts, text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, ) if self.args.on_the_fly_feats: @@ -311,6 +332,7 @@ class LibriHeavyAsrDataModule: input_transforms=input_transforms, return_cuts=self.args.return_cuts, text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, ) if self.args.bucketing_sampler: @@ -378,12 +400,14 @@ class LibriHeavyAsrDataModule: ), return_cuts=self.args.return_cuts, text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, ) else: validate = PromptASRDataset( cut_transforms=transforms, return_cuts=self.args.return_cuts, text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, ) valid_sampler = DynamicBucketingSampler( cuts_valid, @@ -430,7 +454,7 @@ class LibriHeavyAsrDataModule: if self.args.use_context_list: path = ( self.args.manifest_dir - / f"libriheavy_cuts_{self.args.subset}_with_context_list.jsonl.gz" + / f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz" ) elif self.args.with_decoding: path = (