mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-08 08:34:19 +00:00
add a few args to support context list and rare words

parent 4420788f66
commit 0982db9cde
@@ -23,7 +23,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional

 import torch
-from dataset import PromptASRDataset
+from dataset2 import PromptASRDataset
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -70,6 +70,15 @@ class LibriHeavyAsrDataModule:

     def __init__(self, args: argparse.Namespace):
         self.args = args

+        if args.use_context_list:
+            from dataset2 import PromptASRDataset
+            assert args.rare_word_file is not None
+            with open(args.rare_word_file, 'r') as f:
+                self.rare_word_list = f.read().lower().split()  # Use lower-cased for easier style transform
+        else:
+            from dataset import PromptASRDataset
+            self.rare_word_list = None
+
     @classmethod
     def add_arguments(cls, parser: argparse.ArgumentParser):
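The new branch in __init__ treats the rare-word file as plain whitespace-separated text and lower-cases it before splitting. A minimal standalone sketch of just that loading step follows; the file name rare_words.txt is an assumed example, not necessarily the file used in the recipe:

# Sketch of the rare-word-list loading added above; assumes a plain-text file
# with whitespace-separated words ("rare_words.txt" is only an example name).
with open("rare_words.txt", "r") as f:
    # Lower-cased so that downstream text/style transforms can match case-insensitively.
    rare_word_list = f.read().lower().split()

print(len(rare_word_list), rare_word_list[:5])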
@@ -201,6 +210,12 @@ class LibriHeavyAsrDataModule:
             help="Use the context list of libri heavy",
         )

+        group.add_argument(
+            "--min-count",
+            type=int,
+            default=7,
+        )
+
         group.add_argument(
             "--with-decoding",
             type=str2bool,
@@ -212,6 +227,11 @@ class LibriHeavyAsrDataModule:
             "--random-left-padding",
             type=str2bool,
         )

+        group.add_argument(
+            "--rare-word-file",
+            type=str,
+        )
+
     def train_dataloaders(
         self,
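Since argparse maps --min-count to args.min_count and --rare-word-file to args.rare_word_file, the new options combine with --use-context-list roughly as sketched below. This is a hedged re-creation of only the options touched by this commit; the defaults for the boolean flag and the str2bool helper are assumptions standing in for the recipe's own definitions:

import argparse


def str2bool(v: str) -> bool:
    # Minimal stand-in for the recipe's str2bool helper (assumption).
    return v.lower() in ("yes", "true", "t", "y", "1")


parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="ASR data options")
group.add_argument("--use-context-list", type=str2bool, default=False)  # default assumed
group.add_argument("--min-count", type=int, default=7)
group.add_argument("--rare-word-file", type=str)

# Illustrative invocation only; the real entry point and file path may differ.
args = parser.parse_args(
    "--use-context-list 1 --min-count 7 --rare-word-file rare_words.txt".split()
)
assert args.use_context_list and args.min_count == 7
print(args.rare_word_file)  # rare_words.txt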
@@ -290,6 +310,7 @@ class LibriHeavyAsrDataModule:
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
             text_sampling_func=text_sampling_func,
+            rare_word_list=self.rare_word_list,
         )

         if self.args.on_the_fly_feats:
@@ -311,6 +332,7 @@ class LibriHeavyAsrDataModule:
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
+                rare_word_list=self.rare_word_list,
             )

         if self.args.bucketing_sampler:
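Each PromptASRDataset construction site now forwards rare_word_list, so the dataset itself must accept that keyword. Below is a simplified constructor shape, purely to illustrate the interface the data module relies on; it is not the actual code in dataset.py / dataset2.py:

from typing import Callable, List, Optional


class PromptASRDatasetSketch:
    """Stand-in showing only the keyword arguments passed by the data module."""

    def __init__(
        self,
        cut_transforms=None,
        input_transforms=None,
        return_cuts: bool = False,
        text_sampling_func: Optional[Callable] = None,
        rare_word_list: Optional[List[str]] = None,
    ):
        # The real dataset is expected to consult rare_word_list when sampling
        # prompt/reference texts; here it is merely stored.
        self.cut_transforms = cut_transforms
        self.input_transforms = input_transforms
        self.return_cuts = return_cuts
        self.text_sampling_func = text_sampling_func
        self.rare_word_list = rare_word_list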
@@ -378,12 +400,14 @@ class LibriHeavyAsrDataModule:
                 ),
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
+                rare_word_list=self.rare_word_list,
             )
         else:
             validate = PromptASRDataset(
                 cut_transforms=transforms,
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
+                rare_word_list=self.rare_word_list,
             )
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
@@ -430,7 +454,7 @@ class LibriHeavyAsrDataModule:
         if self.args.use_context_list:
             path = (
                 self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}_with_context_list.jsonl.gz"
+                / f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz"
             )
         elif self.args.with_decoding:
             path = (
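With --min-count baked into the manifest name, the cuts path resolved when --use-context-list is enabled changes accordingly. A small sketch with illustrative values; the subset "small", min count 7, and the data/fbank manifest directory are assumptions, not values taken from the diff:

from pathlib import Path

# Illustrative values only; in the recipe these come from self.args.
manifest_dir = Path("data/fbank")
subset = "small"
min_count = 7

path = (
    manifest_dir
    / f"libriheavy_cuts_{subset}_with_context_list_min_count_{min_count}.jsonl.gz"
)
print(path)  # data/fbank/libriheavy_cuts_small_with_context_list_min_count_7.jsonl.gz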