adapt for S, M and L training subsets

luomingshuang 2022-04-27 13:43:54 +08:00
parent 0930748b61
commit 4b567e480f
3 changed files with 46 additions and 9 deletions

View File

@@ -198,6 +198,13 @@ class WenetSpeechAsrDataModule:
help="lazily open CutSets to avoid OOM (for L|XL subset)",
)
group.add_argument(
"--training-subset",
type=str,
default="L",
help="The training subset for using",
)
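The new flag is registered on the datamodule's argument group, so it is available anywhere the datamodule's arguments are added to a parser. A minimal sketch of reading it, assuming icefall's usual WenetSpeechAsrDataModule.add_arguments hook (the import path and hook are assumptions, not part of this diff):

import argparse

# Assumed import; the datamodule lives in the recipe's asr_datamodule.py.
from asr_datamodule import WenetSpeechAsrDataModule

parser = argparse.ArgumentParser()
WenetSpeechAsrDataModule.add_arguments(parser)  # registers --training-subset among others
args = parser.parse_args(["--training-subset", "M"])
print(args.training_subset)  # -> "M"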
def train_dataloaders(
self,
cuts_train: CutSet,
@@ -316,6 +323,7 @@ class WenetSpeechAsrDataModule:
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
logging.info(f"Sampler state dict keys: {list(sampler_state_dict.keys())}")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
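The sampler state restored here is what lets a resumed run continue from the same position in the shuffled cut stream. A minimal round-trip sketch with a lhotse DynamicBucketingSampler (checkpoint layout and paths are illustrative, not taken from this recipe):

import torch
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

cuts = CutSet.from_file("data/fbank/cuts_S.jsonl.gz")  # illustrative path
sampler = DynamicBucketingSampler(cuts, max_duration=200, shuffle=True)

# At checkpoint time, persist the sampler position alongside the model weights.
torch.save({"sampler": sampler.state_dict()}, "checkpoint.pt")

# On resume, restore it before building the train dataloader.
state = torch.load("checkpoint.pt", map_location="cpu")
sampler.load_state_dict(state["sampler"])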
@@ -396,7 +404,6 @@ class WenetSpeechAsrDataModule:
world_size=1,
shuffle=False,
)
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
test_iter_dataset = IterableDatasetWrapper(
@@ -417,14 +424,12 @@
logging.info("use lazy cuts")
cuts_train = CutSet.from_jsonl_lazy(
self.args.manifest_dir
/ "cuts_L.jsonl.gz"
# use cuts_L_50_pieces.jsonl.gz to reproduce the original experiments
/ f"cuts_{self.args.training_subset}.jsonl.gz"
)
else:
cuts_train = CutSet.from_file(
self.args.manifest_dir
/ "cuts_L.jsonl.gz"
# use cuts_L_50_pieces.jsonl.gz to reproduce the original experiments
/ f"cuts_{self.args.training_subset}.jsonl.gz"
)
return cuts_train
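Condensed, the selection above reduces to building the manifest filename from the subset name and choosing lazy or eager loading; a sketch under the assumption that the S, M and L manifests all sit in the same manifest directory:

from pathlib import Path
from lhotse import CutSet

def load_train_cuts(manifest_dir: Path, training_subset: str, lazy: bool) -> CutSet:
    # e.g. cuts_S.jsonl.gz, cuts_M.jsonl.gz or cuts_L.jsonl.gz
    path = manifest_dir / f"cuts_{training_subset}.jsonl.gz"
    if lazy:
        # Lazy loading avoids holding the whole manifest in memory (important for L/XL).
        return CutSet.from_jsonl_lazy(path)
    return CutSet.from_file(path)

cuts_train = load_train_cuts(Path("data/fbank"), "M", lazy=True)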

View File

@@ -101,6 +101,15 @@ def get_parser():
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--batch",
type=int,
default=None,
help="It specifies the batch checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
@@ -499,6 +508,11 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
elif params.batch is not None:
filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt"
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints([filenames], device=device))
else:
start = params.epoch - params.avg + 1
filenames = []
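Taken together with the surrounding branches, the checkpoint selection in main() ends up roughly as sketched below (simplified for readability; params, model and device come from the script itself, and the final averaging branch is reconstructed from icefall's usual pattern rather than copied verbatim):

from icefall.checkpoint import average_checkpoints, load_checkpoint

if params.avg == 1:
    # A single epoch checkpoint, loaded directly.
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
elif params.batch is not None:
    # A single periodic checkpoint, saved as checkpoint-<batch>.pt during training.
    filename = f"{params.exp_dir}/checkpoint-{params.batch}.pt"
    model.load_state_dict(average_checkpoints([filename], device=device))
else:
    # Average the last `avg` epoch checkpoints ending at `epoch`.
    start = params.epoch - params.avg + 1
    filenames = [f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)]
    model.load_state_dict(average_checkpoints(filenames, device=device))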

View File

@@ -251,7 +251,7 @@ def get_parser():
parser.add_argument(
"--save-every-n",
type=int,
default=200,
default=8000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -279,6 +279,26 @@ def get_parser():
help="Whether to use half precision training.",
)
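With the larger default of 8000, periodic checkpoints are written whenever the global batch index is a multiple of save_every_n, producing the checkpoint-<batch>.pt files that decode.py can now load via --batch. A minimal sketch of that pattern, assuming a plain torch.save; the real script goes through icefall's checkpoint helpers, so treat the file layout as illustrative:

import torch

def maybe_save_periodic_checkpoint(model, params) -> None:
    # Illustrative: write exp-dir/checkpoint-<batch>.pt every save_every_n batches.
    if params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0:
        torch.save(
            {"model": model.state_dict(), "batch_idx_train": params.batch_idx_train},
            f"{params.exp_dir}/checkpoint-{params.batch_idx_train}.pt",
        )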
parser.add_argument(
"--valid-interval",
type=int,
default=3000,
help="""When training_subset is L, set the valid_interval to 3000.
When training_subset is M, set the valid_interval to 1000.
When training_subset is S, set the valid_interval to 400.
""",
)
parser.add_argument(
"--model-warm-step",
type=int,
default=3000,
help="""When training_subset is L, set the model_warm_step to 3000.
When training_subset is M, set the model_warm_step to 500.
When training_subset is S, set the model_warm_step to 100.
""",
)
return parser
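Since valid_interval and model_warm_step are removed from get_params() below and become command-line arguments, they presumably get folded into params after parsing. A hedged sketch of that wiring (the exact assignment code is not part of this diff; AttributeDict is a dict subclass, so vars(args) can be merged in):

# Sketch only: merge parsed arguments into params after get_params().
parser = get_parser()
args = parser.parse_args(["--valid-interval", "1000", "--model-warm-step", "500"])

params = get_params()
params.update(vars(args))

# The M-subset recommendations from the help strings above.
assert params.valid_interval == 1000
assert params.model_warm_step == 500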
@@ -333,9 +353,8 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 1,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000,
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
@@ -348,7 +367,6 @@
# parameters for joiner
"joiner_dim": 512,
# parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(),
}
)