add a few args to support context list and rare words

Author: marcoyang1998
Date:   2023-08-16 16:44:58 +08:00
Parent: 4420788f66
Commit: 0982db9cde


@@ -23,7 +23,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional
 import torch
-from dataset import PromptASRDataset
+from dataset2 import PromptASRDataset
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -70,6 +70,15 @@ class LibriHeavyAsrDataModule:
     def __init__(self, args: argparse.Namespace):
         self.args = args
+        if args.use_context_list:
+            from dataset2 import PromptASRDataset
+            assert args.rare_word_file is not None
+            with open(args.rare_word_file, 'r') as f:
+                self.rare_word_list = f.read().lower().split()  # Use lower-cased for easier style transform
+        else:
+            from dataset import PromptASRDataset
+            self.rare_word_list = None
     @classmethod
     def add_arguments(cls, parser: argparse.ArgumentParser):
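
For context, a hedged standalone sketch of what the new rare-word handling assumes: args.rare_word_file points at a plain-text file of whitespace-separated words, which __init__ lower-cases and splits into a list. The file name and contents below are made up for illustration only.

from pathlib import Path

# Hypothetical example file; the real list is produced elsewhere in the recipe.
rare_word_file = Path("rare_words_example.txt")
rare_word_file.write_text("Benedict CUMBERBATCH Tchaikovsky\n")

with open(rare_word_file, "r") as f:
    rare_word_list = f.read().lower().split()  # mirrors the new __init__ logic

print(rare_word_list)  # ['benedict', 'cumberbatch', 'tchaikovsky']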
@@ -201,6 +210,12 @@ class LibriHeavyAsrDataModule:
             help="Use the context list of libri heavy",
         )
+        group.add_argument(
+            "--min-count",
+            type=int,
+            default=7,
+        )
         group.add_argument(
             "--with-decoding",
             type=str2bool,
@@ -212,6 +227,11 @@ class LibriHeavyAsrDataModule:
             "--random-left-padding",
             type=str2bool,
         )
+        group.add_argument(
+            "--rare-word-file",
+            type=str,
+        )
     def train_dataloaders(
         self,
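
The two new options can be exercised without the rest of the recipe. Below is a hedged sketch of an equivalent standalone parser: str2bool is a simplified stand-in for icefall's utility of the same name, and the defaults shown for --use-context-list and --rare-word-file are assumptions; only the default of 7 for --min-count comes from this diff.

import argparse

def str2bool(v: str) -> bool:
    # Simplified stand-in for icefall.utils.str2bool.
    return str(v).lower() in ("yes", "true", "t", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--use-context-list", type=str2bool, default=False)
parser.add_argument("--min-count", type=int, default=7)
parser.add_argument("--rare-word-file", type=str, default=None)

args = parser.parse_args(
    ["--use-context-list", "true", "--min-count", "10", "--rare-word-file", "rare_words.txt"]
)
print(args.use_context_list, args.min_count, args.rare_word_file)
# True 10 rare_words.txt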
@@ -290,6 +310,7 @@ class LibriHeavyAsrDataModule:
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
             text_sampling_func=text_sampling_func,
+            rare_word_list=self.rare_word_list,
         )
         if self.args.on_the_fly_feats:
@@ -311,6 +332,7 @@ class LibriHeavyAsrDataModule:
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
             text_sampling_func=text_sampling_func,
+            rare_word_list=self.rare_word_list,
         )
         if self.args.bucketing_sampler:
@@ -378,12 +400,14 @@ class LibriHeavyAsrDataModule:
                 ),
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
+                rare_word_list=self.rare_word_list,
             )
         else:
             validate = PromptASRDataset(
                 cut_transforms=transforms,
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
+                rare_word_list=self.rare_word_list,
             )
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
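
The dataset hunks above just thread the list through: each of the four PromptASRDataset constructions touched by this commit now receives rare_word_list=self.rare_word_list, which stays None unless --use-context-list is enabled. A minimal stand-in sketch of that wiring (not the real dataset2.PromptASRDataset, which takes more arguments):

from typing import List, Optional

class _FakePromptASRDataset:
    # Stand-in with only the keyword relevant here.
    def __init__(self, rare_word_list: Optional[List[str]] = None, **kwargs):
        self.rare_word_list = rare_word_list

rare_word_list = ["cumberbatch", "tchaikovsky"]  # as loaded in __init__
train = _FakePromptASRDataset(rare_word_list=rare_word_list)
validate = _FakePromptASRDataset(rare_word_list=rare_word_list)
assert train.rare_word_list is validate.rare_word_list  # same list object is shared, not copied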
@@ -430,7 +454,7 @@ class LibriHeavyAsrDataModule:
         if self.args.use_context_list:
             path = (
                 self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}_with_context_list.jsonl.gz"
+                / f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz"
             )
         elif self.args.with_decoding:
             path = (
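
Finally, the cuts manifest used for the context-list case now encodes --min-count in its filename. An illustration with hypothetical values (the manifest directory and subset below are examples, not values from the commit):

from pathlib import Path

manifest_dir = Path("data/fbank")  # example value for args.manifest_dir
subset = "medium"                  # example value for args.subset
min_count = 7                      # the new --min-count default

path = manifest_dir / f"libriheavy_cuts_{subset}_with_context_list_min_count_{min_count}.jsonl.gz"
print(path)
# data/fbank/libriheavy_cuts_medium_with_context_list_min_count_7.jsonl.gz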