Combine the training data and decode without webdataset

This commit is contained in:
luomingshuang 2022-06-08 15:35:53 +08:00
parent 296303abdc
commit c8cb425e51
3 changed files with 47 additions and 64 deletions

View File

@@ -209,13 +209,6 @@ class Aishell4AsrDataModule:
             help="AudioSamples or PrecomputedFeatures",
         )
-        group.add_argument(
-            "--training-subset",
-            type=str,
-            default="L",
-            help="The training subset for using",
-        )
-
     def train_dataloaders(
         self,
         cuts_train: CutSet,
@@ -379,18 +372,14 @@ class Aishell4AsrDataModule:
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-        dev_iter_dataset = IterableDatasetWrapper(
-            dataset=validate,
-            sampler=valid_sampler,
-        )
         valid_dl = DataLoader(
-            dev_iter_dataset,
+            validate,
+            sampler=valid_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
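
This hunk (and the matching test-dataloader hunk below) drops the IterableDatasetWrapper indirection in favor of lhotse's map-style pattern: the sampler yields CutSet mini-batches, and the dataset's __getitem__ turns each CutSet into tensors. A minimal sketch of that pattern, assuming a placeholder manifest path and arbitrary max_duration/num_workers values:

from torch.utils.data import DataLoader
from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset

cuts = load_manifest_lazy("aishell4_cuts_dev.jsonl.gz")  # placeholder path
sampler = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=False)
valid_dl = DataLoader(
    K2SpeechRecognitionDataset(),  # map-style: __getitem__ receives a CutSet
    sampler=sampler,   # the lhotse sampler yields CutSet mini-batches
    batch_size=None,   # disable PyTorch's own batching
    num_workers=2,
)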
@@ -409,27 +398,38 @@ class Aishell4AsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
+        logging.info("About to create test dataloader")
         test_dl = DataLoader(
-            test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl

     @lru_cache()
-    def train_cuts(self) -> CutSet:
-        logging.info("About to get train cuts")
+    def train_S_cuts(self) -> CutSet:
+        logging.info("About to get S train cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir
-            / f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / "aishell4_cuts_train_S.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_M_cuts(self) -> CutSet:
+        logging.info("About to get M train cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell4_cuts_train_M.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_L_cuts(self) -> CutSet:
+        logging.info("About to get L train cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell4_cuts_train_L.jsonl.gz"
         )

     @lru_cache()
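
The three subset accessors each return a lazily-opened manifest, which matters because train.py (below) concatenates them: lhotse CutSets support `+`, chaining the lazy iterators rather than loading everything into memory. A standalone sketch of the same idea, with placeholder paths:

from lhotse import load_manifest_lazy

# Paths are placeholders; the datamodule resolves them under args.manifest_dir.
cuts = (
    load_manifest_lazy("aishell4_cuts_train_S.jsonl.gz")
    + load_manifest_lazy("aishell4_cuts_train_M.jsonl.gz")
    + load_manifest_lazy("aishell4_cuts_train_L.jsonl.gz")
)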

View File

@@ -74,6 +74,8 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from lhotse.cut import Cut
+from local.text_normalize import text_normalize
 from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
@@ -380,6 +382,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        texts = [list(str(text).replace(" ", "")) for text in texts]

         hyps_dict = decode_one_batch(
             params=params,
@@ -393,8 +396,7 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for hyp_words, ref_text in zip(hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((ref_text, hyp_words))

             results[name].extend(this_batch)
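
Together, these two decode.py hunks switch scoring from space-separated words to individual characters: each reference is stripped of spaces and split into a character list, matching the hypothesis tokenization, so the error rate is effectively CER. For example:

texts = ["你好 世界"]
texts = [list(str(text).replace(" ", "")) for text in texts]
# texts == [["你", "好", "世", "界"]]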
@@ -597,38 +599,17 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        c.supervisions[0].text = text_normalize(text)
+        return c

     aishell4 = Aishell4AsrDataModule(args)
-    test = "test"
-    if not os.path.exists(f"{test}/shared-0.tar"):
-        os.makedirs(test)
     test_cuts = aishell4.test_cuts()
-    export_to_webdataset(
-        test_cuts,
-        output_path=f"{test}/shared-%d.tar",
-        shard_size=300,
-    )
-    test_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
-    ]
-    cuts_test_webdataset = CutSet.from_webdataset(
-        test_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-    test_dl = aishell4.test_dataloaders(cuts_test_webdataset)
+    test_cuts = test_cuts.map(text_normalize_for_cut)
+    test_dl = aishell4.test_dataloaders(test_cuts)

     test_sets = ["test"]
     test_dl = [test_dl]
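
CutSet.map applies the function lazily, cut by cut, as the manifest is iterated, so the normalization costs nothing up front; the function must return the (possibly modified) Cut. A toy sketch of the same idiom, with an illustrative function name and placeholder path:

from lhotse import load_manifest_lazy

def clean_text(c):
    # Mirrors text_normalize_for_cut above: mutate the supervision
    # text and return the cut.
    c.supervisions[0].text = c.supervisions[0].text.strip()
    return c

cuts = load_manifest_lazy("aishell4_cuts_test.jsonl.gz")  # placeholder path
cuts = cuts.map(clean_text)  # evaluated lazily during iteration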

View File

@@ -389,14 +389,14 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 1,
-            "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "log_interval": 50,
+            "reset_interval": 100,
+            "valid_interval": 200,
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
             # parameters for Noam
-            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "model_warm_step": 50,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
@@ -942,8 +942,10 @@ def run(rank, world_size, args):
         diagnostic = diagnostics.attach_diagnostics(model, opts)

     aishell4 = Aishell4AsrDataModule(args)
-    train_cuts = aishell4.train_cuts()
+    # Combine all of the training data
+    train_cuts = aishell4.train_S_cuts()
+    train_cuts += aishell4.train_M_cuts()
+    train_cuts += aishell4.train_L_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
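
The filter in the context lines is the usual lhotse predicate idiom: the function returns a bool per cut, and CutSet.filter applies it lazily over the combined training cuts. A sketch consistent with the comment above (the exact body lives in the recipe):

from lhotse.cut import Cut

def remove_short_and_long_utt(c: Cut) -> bool:
    # Bounds taken from the comment in the hunk above.
    return 1.0 <= c.duration <= 20.0

train_cuts = train_cuts.filter(remove_short_and_long_utt)  # lazy predicate filter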