mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
adapt for S, M and L training subset
This commit is contained in:
parent
0930748b61
commit
4b567e480f
@ -198,6 +198,13 @@ class WenetSpeechAsrDataModule:
|
||||
help="lazily open CutSets to avoid OOM (for L|XL subset)",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--training-subset",
|
||||
type=str,
|
||||
default="L",
|
||||
help="The training subset for using",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
@ -316,6 +323,7 @@ class WenetSpeechAsrDataModule:
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
print(sampler_state_dict.keys())
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
@ -396,7 +404,6 @@ class WenetSpeechAsrDataModule:
|
||||
world_size=1,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
|
||||
|
||||
test_iter_dataset = IterableDatasetWrapper(
|
||||
@ -417,14 +424,12 @@ class WenetSpeechAsrDataModule:
|
||||
logging.info("use lazy cuts")
|
||||
cuts_train = CutSet.from_jsonl_lazy(
|
||||
self.args.manifest_dir
|
||||
/ "cuts_L.jsonl.gz"
|
||||
# use cuts_L_50_pieces.jsonl.gz for original experiments
|
||||
/ f"cuts_{self.args.training_subset}.jsonl.gz"
|
||||
)
|
||||
else:
|
||||
cuts_train = CutSet.from_file(
|
||||
self.args.manifest_dir
|
||||
/ "cuts_L.jsonl.gz"
|
||||
# use cuts_L_50_pieces.jsonl.gz for original experiments
|
||||
/ f"cuts_{self.args.training_subset}.jsonl.gz"
|
||||
)
|
||||
return cuts_train
|
||||
|
||||
|
@ -101,6 +101,15 @@ def get_parser():
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch",
|
||||
type=int,
|
||||
default=None,
|
||||
help="It specifies the batch checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
@ -499,6 +508,11 @@ def main():
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
elif params.batch is not None:
|
||||
filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt"
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints([filenames], device=device))
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
|
@ -251,7 +251,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=200,
|
||||
default=8000,
|
||||
help="""Save checkpoint after processing this number of batches"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
@ -279,6 +279,26 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--valid-interval",
|
||||
type=int,
|
||||
default=3000,
|
||||
help="""When training_subset is L, set the valid_interval to 3000.
|
||||
When training_subset is M, set the valid_interval to 1000.
|
||||
When training_subset is S, set the valid_interval to 400.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-warm-step",
|
||||
type=int,
|
||||
default=3000,
|
||||
help="""When training_subset is L, set the model_warm_step to 3000.
|
||||
When training_subset is M, set the model_warm_step to 500.
|
||||
When training_subset is S, set the model_warm_step to 100.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -333,9 +353,8 @@ def get_params() -> AttributeDict:
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 1,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000,
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
@ -348,7 +367,6 @@ def get_params() -> AttributeDict:
|
||||
# parameters for joiner
|
||||
"joiner_dim": 512,
|
||||
# parameters for Noam
|
||||
"model_warm_step": 3000, # arg given to model, not for lrate
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user