Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-08 08:34:19 +00:00

Commit c8cb425e51 (parent 296303abdc)

    combine the training data and decode without webdataset

The commit does two things: it merges the AISHELL-4 S, M and L training subsets into a single training set (dropping the --training-subset flag), and it replaces the webdataset-based test decoding path with plain lazy Lhotse CutSets.
asr_datamodule.py (file inferred from the hunk context, class Aishell4AsrDataModule):

@@ -209,13 +209,6 @@ class Aishell4AsrDataModule:
             help="AudioSamples or PrecomputedFeatures",
         )
 
-        group.add_argument(
-            "--training-subset",
-            type=str,
-            default="L",
-            help="The training subset for using",
-        )
-
     def train_dataloaders(
         self,
         cuts_train: CutSet,
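For context: with the S, M and L subsets now always combined (see the train.py hunk at the end), there is nothing left for --training-subset to select, so the flag goes away. A hypothetical sketch of how such a flag picks a manifest, mirroring the f-string removed in a later datamodule hunk:

    # Hypothetical sketch; mirrors the removed manifest f-string.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--training-subset", type=str, default="L")
    args = parser.parse_args(["--training-subset", "M"])
    print(f"aishell4_cuts_train_{args.training_subset}.jsonl.gz")
    # -> aishell4_cuts_train_M.jsonl.gz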
@@ -379,18 +372,14 @@ class Aishell4AsrDataModule:
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
 
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-
-        dev_iter_dataset = IterableDatasetWrapper(
-            dataset=validate,
-            sampler=valid_sampler,
-        )
         valid_dl = DataLoader(
-            dev_iter_dataset,
+            validate,
+            sampler=valid_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
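Both dataloader hunks converge on the same pattern: a map-style Lhotse dataset plus a sampler handed straight to DataLoader. A minimal runnable sketch of that pattern (the manifest path is an assumption, not taken from this diff):

    # batch_size=None because the Lhotse sampler already yields whole
    # batches (CutSets); the dataset's __getitem__ collates each one.
    from lhotse import load_manifest_lazy
    from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
    from torch.utils.data import DataLoader

    cuts = load_manifest_lazy("data/fbank/aishell4_cuts_test.jsonl.gz")  # assumed path
    dataset = K2SpeechRecognitionDataset(return_cuts=True)
    sampler = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=False)
    dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)

    for batch in dl:
        print(batch["supervisions"]["text"][:2])
        break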
@@ -409,27 +398,38 @@ class Aishell4AsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
         logging.info("About to create test dataloader")
         test_dl = DataLoader(
-            test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl
 
     @lru_cache()
-    def train_cuts(self) -> CutSet:
-        logging.info("About to get train cuts")
+    def train_S_cuts(self) -> CutSet:
+        logging.info("About to get S train cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir
-            / f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / "aishell4_cuts_train_S.jsonl.gz"
         )
 
+    @lru_cache()
+    def train_M_cuts(self) -> CutSet:
+        logging.info("About to get M train cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell4_cuts_train_M.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_L_cuts(self) -> CutSet:
+        logging.info("About to get L train cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell4_cuts_train_L.jsonl.gz"
+        )
+
     @lru_cache()
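The three new accessors replace the parameterized train_cuts(); each one opens its manifest lazily, and @lru_cache makes repeated calls return the same CutSet object instead of re-reading the file. A small sketch of that caching behaviour (free-function form, hypothetical path):

    from functools import lru_cache
    from lhotse import CutSet, load_manifest_lazy

    @lru_cache()
    def train_s_cuts() -> CutSet:
        # Lazy: the manifest is streamed when iterated, not loaded here.
        return load_manifest_lazy("data/fbank/aishell4_cuts_train_S.jsonl.gz")

    assert train_s_cuts() is train_s_cuts()  # cached: same object both times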
decode.py (file inferred from the hunk context, decode_dataset and the beam_search imports):

@@ -74,6 +74,8 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from lhotse.cut import Cut
+from local.text_normalize import text_normalize
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -380,6 +382,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        texts = [list(str(text).replace(" ", "")) for text in texts]
 
         hyps_dict = decode_one_batch(
             params=params,
@@ -393,8 +396,7 @@
             this_batch = []
             assert len(hyps) == len(texts)
             for hyp_words, ref_text in zip(hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((ref_text, hyp_words))
 
         results[name].extend(this_batch)
 
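These two hunks switch scoring to character level: references are stripped of spaces and split into characters up front, so ref_text is already a token list and the ref_text.split() step disappears. What the added line does, on a made-up reference:

    # Mandarin transcripts carry no word boundaries, so score per character.
    text = "甚至 出现 交易"  # made-up example reference
    tokens = list(str(text).replace(" ", ""))
    print(tokens)  # ['甚', '至', '出', '现', '交', '易']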
@@ -597,38 +599,17 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        c.supervisions[0].text = text_normalize(text)
+        return c
 
     aishell4 = Aishell4AsrDataModule(args)
 
-    test = "test"
-    if not os.path.exists(f"{test}/shared-0.tar"):
-        os.makedirs(test)
-        test_cuts = aishell4.test_cuts()
-        export_to_webdataset(
-            test_cuts,
-            output_path=f"{test}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    test_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
-    ]
-    cuts_test_webdataset = CutSet.from_webdataset(
-        test_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_dl = aishell4.test_dataloaders(cuts_test_webdataset)
+    test_cuts = aishell4.test_cuts()
+    test_cuts = test_cuts.map(text_normalize_for_cut)
+    test_dl = aishell4.test_dataloaders(test_cuts)
 
     test_sets = ["test"]
     test_dl = [test_dl]
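The property that makes the webdataset export/reimport pass unnecessary is that CutSet.map is lazy: the normalization closure runs as the sampler pulls cuts, with nothing materialized beforehand. A minimal sketch (the path and the normalizer body are stand-ins, not from the recipe):

    from lhotse import load_manifest_lazy

    def text_normalize(s: str) -> str:
        # Stand-in for the normalizer in local/text_normalize.py.
        return s.upper()

    def normalize_cut(c):
        c.supervisions[0].text = text_normalize(c.supervisions[0].text)
        return c

    cuts = load_manifest_lazy("data/fbank/aishell4_cuts_test.jsonl.gz")  # assumed path
    cuts = cuts.map(normalize_cut)  # lazy: nothing is read or transformed yet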
train.py (file inferred from the hunk context, get_params and run):

@@ -389,14 +389,14 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 1,
-            "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "log_interval": 50,
+            "reset_interval": 100,
+            "valid_interval": 200,
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
             # parameters for Noam
-            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "model_warm_step": 50,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
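For orientation, these counters are consumed modulo the batch index inside the training loop, so the new values log and validate far more often. A runnable toy sketch of that cadence (not this file's actual loop):

    import logging

    logging.basicConfig(level=logging.INFO)

    params = {"log_interval": 50, "valid_interval": 200}
    for batch_idx in range(1000):  # stand-in for enumerate(train_dl)
        if batch_idx % params["log_interval"] == 0:
            logging.info(f"batch {batch_idx}: would log training loss here")
        if batch_idx % params["valid_interval"] == 0:
            logging.info(f"batch {batch_idx}: would run validation here")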
@@ -942,8 +942,10 @@ def run(rank, world_size, args):
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
     aishell4 = Aishell4AsrDataModule(args)
 
-    train_cuts = aishell4.train_cuts()
+    # Combine all of the training data
+    train_cuts = aishell4.train_S_cuts()
+    train_cuts += aishell4.train_M_cuts()
+    train_cuts += aishell4.train_L_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
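The `+` here is lazy concatenation: Lhotse chains the underlying manifests and streams them during training, so combining S, M and L never loads everything into memory. A standalone sketch (the manifest directory is assumed, not taken from this diff):

    from lhotse import load_manifest_lazy

    cuts = load_manifest_lazy("data/fbank/aishell4_cuts_train_S.jsonl.gz")
    cuts += load_manifest_lazy("data/fbank/aishell4_cuts_train_M.jsonl.gz")
    cuts += load_manifest_lazy("data/fbank/aishell4_cuts_train_L.jsonl.gz")
    # Still lazy: the chained manifests are only read when iterated.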