Mirror of https://github.com/k2-fsa/icefall.git
adapt for S, M and L training subsets
commit 4b567e480f
parent 0930748b61
@@ -198,6 +198,13 @@ class WenetSpeechAsrDataModule:
             help="lazily open CutSets to avoid OOM (for L|XL subset)",
         )
 
+        group.add_argument(
+            "--training-subset",
+            type=str,
+            default="L",
+            help="The training subset to use (S, M or L).",
+        )
+
     def train_dataloaders(
         self,
         cuts_train: CutSet,
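The hunk above only registers the flag; as a quick illustration of how it behaves, here is a stand-alone toy parser. This is a sketch for illustration only, since the real argument lives on WenetSpeechAsrDataModule's argument group.

# Illustrative stand-alone parser; the real flag is defined on the
# WenetSpeechAsrDataModule argument group shown in the hunk above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--training-subset",
    type=str,
    default="L",
    help="The training subset to use (S, M or L).",
)

args = parser.parse_args(["--training-subset", "M"])
# The subset name is later interpolated into the manifest filename:
print(f"cuts_{args.training_subset}.jsonl.gz")  # -> cuts_M.jsonl.gz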
@@ -316,6 +323,7 @@ class WenetSpeechAsrDataModule:
 
         if sampler_state_dict is not None:
             logging.info("Loading sampler state dict")
+            print(sampler_state_dict.keys())
             train_sampler.load_state_dict(sampler_state_dict)
 
         # 'seed' is derived from the current random state, which will have
@@ -396,7 +404,6 @@ class WenetSpeechAsrDataModule:
             world_size=1,
             shuffle=False,
         )
-
         from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
 
         test_iter_dataset = IterableDatasetWrapper(
@@ -417,14 +424,12 @@ class WenetSpeechAsrDataModule:
             logging.info("use lazy cuts")
             cuts_train = CutSet.from_jsonl_lazy(
                 self.args.manifest_dir
-                / "cuts_L.jsonl.gz"
-                # use cuts_L_50_pieces.jsonl.gz for original experiments
+                / f"cuts_{self.args.training_subset}.jsonl.gz"
             )
         else:
             cuts_train = CutSet.from_file(
                 self.args.manifest_dir
-                / "cuts_L.jsonl.gz"
-                # use cuts_L_50_pieces.jsonl.gz for original experiments
+                / f"cuts_{self.args.training_subset}.jsonl.gz"
             )
         return cuts_train
 
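The two branches above differ only in how the manifest is opened; a minimal sketch of that distinction, assuming lhotse is installed and the manifests sit under a typical icefall layout such as data/fbank (the directory name and subset choice are assumptions, not part of the diff):

# Minimal sketch of lazy vs. eager manifest loading; "data/fbank" and the
# chosen subset are assumptions for illustration.
from pathlib import Path

from lhotse import CutSet

manifest_dir = Path("data/fbank")
training_subset = "S"  # one of S, M, L (or XL)
path = manifest_dir / f"cuts_{training_subset}.jsonl.gz"

# Lazy mode streams cuts from disk, avoiding OOM on the large L/XL manifests.
cuts_lazy = CutSet.from_jsonl_lazy(path)

# Eager mode loads the whole manifest up front; fine for the smaller subsets.
cuts_eager = CutSet.from_file(path)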
@@ -101,6 +101,15 @@ def get_parser():
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
 
+    parser.add_argument(
+        "--batch",
+        type=int,
+        default=None,
+        help="It specifies the batch checkpoint to use for decoding, "
+        "i.e. exp-dir/checkpoint-{batch}.pt.",
+    )
+
     parser.add_argument(
         "--avg",
         type=int,
@@ -499,6 +508,11 @@ def main():
         model.load_state_dict(average_checkpoints(filenames, device=device))
     elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    elif params.batch is not None:
+        filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt"
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints([filenames], device=device))
     else:
         start = params.epoch - params.avg + 1
         filenames = []
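Read together with the --batch flag added earlier, the branch above lets decoding target a single mid-epoch checkpoint instead of epoch averages. A small self-contained sketch of the resulting file selection follows; it is illustrative only, and the exact branch order in decode.py may differ.

# Illustrative sketch of the checkpoint choices the decode script now supports;
# not the verbatim decode.py control flow.
from typing import List, Optional


def checkpoints_to_load(exp_dir: str, epoch: int, avg: int, batch: Optional[int]) -> List[str]:
    if batch is not None:
        # New in this commit: a single mid-epoch checkpoint, e.g. checkpoint-8000.pt.
        return [f"{exp_dir}/checkpoint-{batch}.pt"]
    if avg == 1:
        return [f"{exp_dir}/epoch-{epoch}.pt"]
    # Average the last `avg` epoch checkpoints.
    start = epoch - avg + 1
    return [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]


print(checkpoints_to_load("exp", epoch=10, avg=2, batch=None))
# -> ['exp/epoch-9.pt', 'exp/epoch-10.pt']
print(checkpoints_to_load("exp", epoch=10, avg=2, batch=8000))
# -> ['exp/checkpoint-8000.pt']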
@@ -251,7 +251,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=200,
+        default=8000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
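Raising the default from 200 to 8000 makes batch checkpoints far less frequent; a tiny illustrative calculation of which checkpoint-<n>.pt files a run would leave behind (the run length is hypothetical, and the naming follows the pattern the --batch decode flag above points at):

# Illustrative only: checkpoint filenames produced during a hypothetical run.
save_every_n = 8000
total_batches = 50_000  # hypothetical run length

checkpoints = [
    f"checkpoint-{i}.pt"
    for i in range(save_every_n, total_batches + 1, save_every_n)
]
print(checkpoints)
# -> ['checkpoint-8000.pt', 'checkpoint-16000.pt', ..., 'checkpoint-48000.pt']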
@@ -279,6 +279,26 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--valid-interval",
+        type=int,
+        default=3000,
+        help="""When training_subset is L, set the valid_interval to 3000.
+        When training_subset is M, set the valid_interval to 1000.
+        When training_subset is S, set the valid_interval to 400.
+        """,
+    )
+
+    parser.add_argument(
+        "--model-warm-step",
+        type=int,
+        default=3000,
+        help="""When training_subset is L, set the model_warm_step to 3000.
+        When training_subset is M, set the model_warm_step to 500.
+        When training_subset is S, set the model_warm_step to 100.
+        """,
+    )
+
     return parser
 
 
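The two help strings amount to a per-subset recipe; the sketch below simply collects those suggested values in one place (the dict is an illustrative summary, not part of the recipe code):

# Suggested settings per training subset, taken from the help strings above.
SUBSET_SETTINGS = {
    # subset: (valid_interval, model_warm_step)
    "L": (3000, 3000),
    "M": (1000, 500),
    "S": (400, 100),
}

subset = "M"  # hypothetical choice
valid_interval, model_warm_step = SUBSET_SETTINGS[subset]
print(valid_interval, model_warm_step)  # -> 1000 500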
@@ -333,9 +353,8 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 1,
+            "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -348,7 +367,6 @@ def get_params() -> AttributeDict:
             # parameters for joiner
             "joiner_dim": 512,
             # parameters for Noam
-            "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
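With the hard-coded entries removed above, valid_interval and model_warm_step have to reach params from the new command-line flags. A hedged sketch of one common icefall pattern for that follows; the exact wiring in train.py is not shown in this diff.

# Hedged sketch: merging parsed CLI arguments into the AttributeDict returned
# by get_params(). This mirrors a common icefall pattern but is not part of
# the diff itself.
args = get_parser().parse_args()
params = get_params()
params.update(vars(args))  # now supplies params.valid_interval and params.model_warm_step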