Add a few arguments to support context lists and rare words

This commit is contained in:
marcoyang1998 2023-08-16 16:44:58 +08:00
parent 4420788f66
commit 0982db9cde

View File

@ -23,7 +23,7 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
from dataset import PromptASRDataset from dataset2 import PromptASRDataset
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( from lhotse.dataset import (
CutConcatenate, CutConcatenate,
@ -71,6 +71,15 @@ class LibriHeavyAsrDataModule:
def __init__(self, args: argparse.Namespace): def __init__(self, args: argparse.Namespace):
self.args = args self.args = args
if args.use_context_list:
from dataset2 import PromptASRDataset
assert args.rare_word_file is not None
with open(args.rare_word_file, 'r') as f:
self.rare_word_list = f.read().lower().split() # Use lower-cased for easier style transform
else:
from dataset import PromptASRDataset
self.rare_word_list = None
@classmethod @classmethod
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
@ -201,6 +210,12 @@ class LibriHeavyAsrDataModule:
help="Use the context list of libri heavy", help="Use the context list of libri heavy",
) )
group.add_argument(
"--min-count",
type=int,
default=7,
)
group.add_argument( group.add_argument(
"--with-decoding", "--with-decoding",
type=str2bool, type=str2bool,
@ -213,6 +228,11 @@ class LibriHeavyAsrDataModule:
type=str2bool, type=str2bool,
) )
group.add_argument(
"--rare-word-file",
type=str,
)
def train_dataloaders( def train_dataloaders(
self, self,
cuts_train: CutSet, cuts_train: CutSet,
@ -290,6 +310,7 @@ class LibriHeavyAsrDataModule:
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
text_sampling_func=text_sampling_func, text_sampling_func=text_sampling_func,
rare_word_list=self.rare_word_list,
) )
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
@ -311,6 +332,7 @@ class LibriHeavyAsrDataModule:
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
text_sampling_func=text_sampling_func, text_sampling_func=text_sampling_func,
rare_word_list=self.rare_word_list,
) )
if self.args.bucketing_sampler: if self.args.bucketing_sampler:
@ -378,12 +400,14 @@ class LibriHeavyAsrDataModule:
), ),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
text_sampling_func=text_sampling_func, text_sampling_func=text_sampling_func,
rare_word_list=self.rare_word_list,
) )
else: else:
validate = PromptASRDataset( validate = PromptASRDataset(
cut_transforms=transforms, cut_transforms=transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
text_sampling_func=text_sampling_func, text_sampling_func=text_sampling_func,
rare_word_list=self.rare_word_list,
) )
valid_sampler = DynamicBucketingSampler( valid_sampler = DynamicBucketingSampler(
cuts_valid, cuts_valid,
@ -430,7 +454,7 @@ class LibriHeavyAsrDataModule:
if self.args.use_context_list: if self.args.use_context_list:
path = ( path = (
self.args.manifest_dir self.args.manifest_dir
/ f"libriheavy_cuts_{self.args.subset}_with_context_list.jsonl.gz" / f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz"
) )
elif self.args.with_decoding: elif self.args.with_decoding:
path = ( path = (