diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py
index 545ec4c0e..3be750b65 100755
--- a/egs/commonvoice/ASR/zipformer/train_char.py
+++ b/egs/commonvoice/ASR/zipformer/train_char.py
@@ -74,6 +74,7 @@ from train import (
     add_model_arguments,
     get_adjusted_batch_count,
     get_model,
+    get_params,
     load_checkpoint_if_available,
     save_checkpoint,
     set_batch_count,
@@ -88,7 +89,6 @@ from icefall.checkpoint import (
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -320,72 +320,6 @@ def get_parser():
     return parser
 
 
-def get_params() -> AttributeDict:
-    """Return a dict containing training parameters.
-
-    All training related parameters that are not passed from the commandline
-    are saved in the variable `params`.
-
-    Commandline options are merged into `params` after they are parsed, so
-    you can also access them via `params`.
-
-    Explanation of options saved in `params`:
-
-        - best_train_loss: Best training loss so far. It is used to select
-                           the model that has the lowest training loss. It is
-                           updated during the training.
-
-        - best_valid_loss: Best validation loss so far. It is used to select
-                           the model that has the lowest validation loss. It is
-                           updated during the training.
-
-        - best_train_epoch: It is the epoch that has the best training loss.
-
-        - best_valid_epoch: It is the epoch that has the best validation loss.
-
-        - batch_idx_train: Used to writing statistics to tensorboard. It
-                           contains number of batches trained so far across
-                           epochs.
-
-        - log_interval: Print training loss if batch_idx % log_interval` is 0
-
-        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
-
-        - valid_interval: Run validation if batch_idx % valid_interval is 0
-
-        - feature_dim: The model input dim. It has to match the one used
-                       in computing features.
-
-        - subsampling_factor: The subsampling factor for the model.
-
-        - encoder_dim: Hidden dim for multi-head attention model.
-
-        - num_decoder_layers: Number of decoder layer of transformer decoder.
-
-        - warm_step: The warmup period that dictates the decay of the
-              scale on "simple" (un-pruned) loss.
-    """
-    params = AttributeDict(
-        {
-            "best_train_loss": float("inf"),
-            "best_valid_loss": float("inf"),
-            "best_train_epoch": -1,
-            "best_valid_epoch": -1,
-            "batch_idx_train": 0,
-            "log_interval": 50,
-            "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
-            # parameters for zipformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,  # not passed in, this is fixed.
-            "warm_step": 2000,
-            "env_info": get_env_info(),
-        }
-    )
-
-    return params
-
-
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -1017,8 +951,8 @@ def main():
     run(rank=0, world_size=1, args=args)
 
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
+# torch.set_num_threads(1)
+# torch.set_num_interop_threads(1)
 
 if __name__ == "__main__":
     main()