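Do some changes: remove the --lazy-load option and always load the training cuts lazily, load the Musan cuts eagerly with load_manifest, raise initial_lr from 1.5e-4 to 1e-3, and skip the pessimistic-batch OOM scan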
This commit is contained in:
parent
02eb6b210b
commit
391cb707fd
@@ -27,6 +27,7 @@ from lhotse import (
|
|||||||
CutSet,
|
CutSet,
|
||||||
Fbank,
|
Fbank,
|
||||||
FbankConfig,
|
FbankConfig,
|
||||||
|
load_manifest,
|
||||||
load_manifest_lazy,
|
load_manifest_lazy,
|
||||||
set_caching_enabled,
|
set_caching_enabled,
|
||||||
)
|
)
|
||||||
@@ -191,13 +192,6 @@ class WenetSpeechAsrDataModule:
|
|||||||
"with training dataset. ",
|
"with training dataset. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
|
||||||
"--lazy-load",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="lazily open CutSets to avoid OOM (for L|XL subset)",
|
|
||||||
)
|
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--training-subset",
|
"--training-subset",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -218,7 +212,7 @@ class WenetSpeechAsrDataModule:
|
|||||||
The state dict for the training sampler.
|
The state dict for the training sampler.
|
||||||
"""
|
"""
|
||||||
logging.info("About to get Musan cuts")
|
logging.info("About to get Musan cuts")
|
||||||
cuts_musan = load_manifest_lazy(
|
cuts_musan = load_manifest(
|
||||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -419,18 +413,10 @@ class WenetSpeechAsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def train_cuts(self) -> CutSet:
|
||||||
logging.info("About to get train cuts")
|
logging.info("About to get train cuts")
|
||||||
if self.args.lazy_load:
|
return load_manifest_lazy(
|
||||||
logging.info("use lazy cuts")
|
self.args.manifest_dir
|
||||||
cuts_train = CutSet.from_jsonl_lazy(
|
/ f"cuts_{self.args.training_subset}.jsonl.gz"
|
||||||
self.args.manifest_dir
|
)
|
||||||
/ f"cuts_{self.args.training_subset}.jsonl.gz"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cuts_train = CutSet.from_file(
|
|
||||||
self.args.manifest_dir
|
|
||||||
/ f"cuts_{self.args.training_subset}.jsonl.gz"
|
|
||||||
)
|
|
||||||
return cuts_train
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def valid_cuts(self) -> CutSet:
|
def valid_cuts(self) -> CutSet:
|
||||||
|
|||||||
@@ -883,7 +883,7 @@ def run(rank, world_size, args):
|
|||||||
dst_state_dict.update(pretrained_dict)
|
dst_state_dict.update(pretrained_dict)
|
||||||
model.load_state_dict(dst_state_dict)
|
model.load_state_dict(dst_state_dict)
|
||||||
|
|
||||||
initial_lr = 1.5e-4
|
initial_lr = 1e-3
|
||||||
|
|
||||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||||
|
|
||||||
@@ -1012,6 +1012,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
):
|
):
|
||||||
|
return
|
||||||
from lhotse.dataset import find_pessimistic_batches
|
from lhotse.dataset import find_pessimistic_batches
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
|
|||||||
Reference in New Issue
Block a user