Combine the training data and decode without webdataset

luomingshuang 2022-06-08 15:35:53 +08:00
parent 296303abdc
commit c8cb425e51
3 changed files with 47 additions and 64 deletions

View File

@@ -209,13 +209,6 @@ class Aishell4AsrDataModule:
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--training-subset",
type=str,
default="L",
help="The training subset for using",
)
def train_dataloaders(
self,
cuts_train: CutSet,
@@ -379,18 +372,14 @@ class Aishell4AsrDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
rank=0,
world_size=1,
shuffle=False,
)
logging.info("About to create dev dataloader")
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
dev_iter_dataset = IterableDatasetWrapper(
dataset=validate,
sampler=valid_sampler,
)
valid_dl = DataLoader(
dev_iter_dataset,
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
@@ -409,27 +398,38 @@ class Aishell4AsrDataModule:
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
rank=0,
world_size=1,
shuffle=False,
)
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
test_iter_dataset = IterableDatasetWrapper(
dataset=test,
sampler=sampler,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test_iter_dataset,
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
def train_S_cuts(self) -> CutSet:
logging.info("About to get S train cuts")
return load_manifest_lazy(
self.args.manifest_dir
/ f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
self.args.manifest_dir / "aishell4_cuts_train_S.jsonl.gz"
)
@lru_cache()
def train_M_cuts(self) -> CutSet:
logging.info("About to get M train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_train_M.jsonl.gz"
)
@lru_cache()
def train_L_cuts(self) -> CutSet:
logging.info("About to get L train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_train_L.jsonl.gz"
)
@lru_cache()
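
Editor's note (not part of the diff): after this change the recipe hands the Lhotse DynamicBucketingSampler straight to a plain DataLoader instead of going through IterableDatasetWrapper. Below is a minimal sketch of that pattern, assuming a K2SpeechRecognitionDataset built on precomputed fbank features; the manifest path and max_duration are placeholders, not values from the commit.

from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
from torch.utils.data import DataLoader

# Placeholder manifest path; in the recipe the cuts come from Aishell4AsrDataModule.
cuts_test = load_manifest_lazy("data/fbank/aishell4_cuts_test.jsonl.gz")

test_dataset = K2SpeechRecognitionDataset(return_cuts=True)
test_sampler = DynamicBucketingSampler(
    cuts_test,
    max_duration=200.0,  # placeholder value
    shuffle=False,
)
test_dl = DataLoader(
    test_dataset,          # map-style dataset; __getitem__ receives a batch of cuts
    sampler=test_sampler,  # the Lhotse sampler yields CutSet batches directly
    batch_size=None,       # batching is done by the sampler, not by DataLoader
    num_workers=2,
)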

View File

@@ -74,6 +74,8 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from local.text_normalize import text_normalize
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@@ -380,6 +382,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
hyps_dict = decode_one_batch(
params=params,
@@ -393,8 +396,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((ref_text, hyp_words))
results[name].extend(this_batch)
@@ -597,38 +599,17 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# Note: Please use "pip install webdataset==0.1.103"
# for installing the webdataset.
import glob
import os
from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset
def text_normalize_for_cut(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
c.supervisions[0].text = text_normalize(text)
return c
aishell4 = Aishell4AsrDataModule(args)
test = "test"
if not os.path.exists(f"{test}/shared-0.tar"):
os.makedirs(test)
test_cuts = aishell4.test_cuts()
export_to_webdataset(
test_cuts,
output_path=f"{test}/shared-%d.tar",
shard_size=300,
)
test_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
]
cuts_test_webdataset = CutSet.from_webdataset(
test_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
test_dl = aishell4.test_dataloaders(cuts_test_webdataset)
test_cuts = aishell4.test_cuts()
test_cuts = test_cuts.map(text_normalize_for_cut)
test_dl = aishell4.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dl = [test_dl]
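
Editor's note (not part of the diff): instead of exporting the test cuts to webdataset shards, decode.py now normalizes transcripts lazily with CutSet.map and scores hypotheses against character-level references. A minimal sketch of both steps, assuming the script runs from the recipe directory so that local.text_normalize is importable; the manifest path is a placeholder.

from lhotse import load_manifest_lazy
from lhotse.cut import Cut
from local.text_normalize import text_normalize


def text_normalize_for_cut(c: Cut) -> Cut:
    # Strip stray whitespace and normalize the transcript of each cut.
    text = c.supervisions[0].text
    c.supervisions[0].text = text_normalize(text.strip("\n").strip("\t"))
    return c


# Placeholder path; the recipe gets the cuts from Aishell4AsrDataModule.test_cuts().
test_cuts = load_manifest_lazy("data/fbank/aishell4_cuts_test.jsonl.gz")
test_cuts = test_cuts.map(text_normalize_for_cut)  # applied lazily, no shard export needed

# References are split into characters before scoring, so errors are counted
# per Chinese character rather than per whitespace-separated word:
ref_text = "你好 世界"
ref_chars = list(ref_text.replace(" ", ""))  # ['你', '好', '世', '界']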

View File

@@ -389,14 +389,14 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 1,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
"log_interval": 50,
"reset_interval": 100,
"valid_interval": 200,
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
# parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate
"model_warm_step": 50, # arg given to model, not for lrate
"env_info": get_env_info(),
}
)
@@ -942,8 +942,10 @@ def run(rank, world_size, args):
diagnostic = diagnostics.attach_diagnostics(model, opts)
aishell4 = Aishell4AsrDataModule(args)
train_cuts = aishell4.train_cuts()
# Combine all of the training data
train_cuts = aishell4.train_S_cuts()
train_cuts += aishell4.train_M_cuts()
train_cuts += aishell4.train_L_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
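
Editor's note (not part of the diff): Lhotse CutSets support the + operator, which chains the underlying manifests lazily, so the three AISHELL-4 training subsets can be combined without loading them into memory. A minimal sketch, with the manifest directory as a placeholder.

from lhotse import load_manifest_lazy

manifest_dir = "data/fbank"  # placeholder; the recipe reads it from --manifest-dir
train_cuts = load_manifest_lazy(f"{manifest_dir}/aishell4_cuts_train_S.jsonl.gz")
train_cuts += load_manifest_lazy(f"{manifest_dir}/aishell4_cuts_train_M.jsonl.gz")
train_cuts += load_manifest_lazy(f"{manifest_dir}/aishell4_cuts_train_L.jsonl.gz")

# Utterances outside the 1-20 second range are then dropped, as in
# remove_short_and_long_utt above:
train_cuts = train_cuts.filter(lambda c: 1.0 <= c.duration <= 20.0)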