mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
add a few args to support context list and rare words
This commit is contained in:
parent
4420788f66
commit
0982db9cde
@ -23,7 +23,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from dataset import PromptASRDataset
|
from dataset2 import PromptASRDataset
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||||
from lhotse.dataset import (
|
from lhotse.dataset import (
|
||||||
CutConcatenate,
|
CutConcatenate,
|
||||||
@ -71,6 +71,15 @@ class LibriHeavyAsrDataModule:
|
|||||||
def __init__(self, args: argparse.Namespace):
|
def __init__(self, args: argparse.Namespace):
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
|
if args.use_context_list:
|
||||||
|
from dataset2 import PromptASRDataset
|
||||||
|
assert args.rare_word_file is not None
|
||||||
|
with open(args.rare_word_file, 'r') as f:
|
||||||
|
self.rare_word_list = f.read().lower().split() # Use lower-cased for easier style transform
|
||||||
|
else:
|
||||||
|
from dataset import PromptASRDataset
|
||||||
|
self.rare_word_list = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
group = parser.add_argument_group(
|
group = parser.add_argument_group(
|
||||||
@ -201,6 +210,12 @@ class LibriHeavyAsrDataModule:
|
|||||||
help="Use the context list of libri heavy",
|
help="Use the context list of libri heavy",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--min-count",
|
||||||
|
type=int,
|
||||||
|
default=7,
|
||||||
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--with-decoding",
|
"--with-decoding",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -213,6 +228,11 @@ class LibriHeavyAsrDataModule:
|
|||||||
type=str2bool,
|
type=str2bool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--rare-word-file",
|
||||||
|
type=str,
|
||||||
|
)
|
||||||
|
|
||||||
def train_dataloaders(
|
def train_dataloaders(
|
||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
@ -290,6 +310,7 @@ class LibriHeavyAsrDataModule:
|
|||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
text_sampling_func=text_sampling_func,
|
text_sampling_func=text_sampling_func,
|
||||||
|
rare_word_list=self.rare_word_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
@ -311,6 +332,7 @@ class LibriHeavyAsrDataModule:
|
|||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
text_sampling_func=text_sampling_func,
|
text_sampling_func=text_sampling_func,
|
||||||
|
rare_word_list=self.rare_word_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.bucketing_sampler:
|
if self.args.bucketing_sampler:
|
||||||
@ -378,12 +400,14 @@ class LibriHeavyAsrDataModule:
|
|||||||
),
|
),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
text_sampling_func=text_sampling_func,
|
text_sampling_func=text_sampling_func,
|
||||||
|
rare_word_list=self.rare_word_list,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
validate = PromptASRDataset(
|
validate = PromptASRDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
text_sampling_func=text_sampling_func,
|
text_sampling_func=text_sampling_func,
|
||||||
|
rare_word_list=self.rare_word_list,
|
||||||
)
|
)
|
||||||
valid_sampler = DynamicBucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
@ -430,7 +454,7 @@ class LibriHeavyAsrDataModule:
|
|||||||
if self.args.use_context_list:
|
if self.args.use_context_list:
|
||||||
path = (
|
path = (
|
||||||
self.args.manifest_dir
|
self.args.manifest_dir
|
||||||
/ f"libriheavy_cuts_{self.args.subset}_with_context_list.jsonl.gz"
|
/ f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz"
|
||||||
)
|
)
|
||||||
elif self.args.with_decoding:
|
elif self.args.with_decoding:
|
||||||
path = (
|
path = (
|
||||||
|
Loading…
x
Reference in New Issue
Block a user