mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
do some changes
This commit is contained in:
parent
02eb6b210b
commit
391cb707fd
@@ -27,6 +27,7 @@ from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
load_manifest,
|
||||
load_manifest_lazy,
|
||||
set_caching_enabled,
|
||||
)
|
||||
@@ -191,13 +192,6 @@ class WenetSpeechAsrDataModule:
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--lazy-load",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="lazily open CutSets to avoid OOM (for L|XL subset)",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--training-subset",
|
||||
type=str,
|
||||
@@ -218,7 +212,7 @@ class WenetSpeechAsrDataModule:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest_lazy(
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "musan_cuts.jsonl.gz"
|
||||
)
|
||||
|
||||
@@ -419,18 +413,10 @@ class WenetSpeechAsrDataModule:
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
if self.args.lazy_load:
|
||||
logging.info("use lazy cuts")
|
||||
cuts_train = CutSet.from_jsonl_lazy(
|
||||
self.args.manifest_dir
|
||||
/ f"cuts_{self.args.training_subset}.jsonl.gz"
|
||||
)
|
||||
else:
|
||||
cuts_train = CutSet.from_file(
|
||||
self.args.manifest_dir
|
||||
/ f"cuts_{self.args.training_subset}.jsonl.gz"
|
||||
)
|
||||
return cuts_train
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir
|
||||
/ f"cuts_{self.args.training_subset}.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
|
@@ -883,7 +883,7 @@ def run(rank, world_size, args):
|
||||
dst_state_dict.update(pretrained_dict)
|
||||
model.load_state_dict(dst_state_dict)
|
||||
|
||||
initial_lr = 1.5e-4
|
||||
initial_lr = 1e-3
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
@@ -1012,6 +1012,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
params: AttributeDict,
|
||||
):
|
||||
return
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
logging.info(
|
||||
|
Loading…
x
Reference in New Issue
Block a user