Update train_char.py

Repository: https://github.com/k2-fsa/icefall.git
commit 921d34abcb (parent 303eb99e47)
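In brief (an editorial summary inferred from the diff below): the commit deletes train_char.py's local copy of get_params() and instead imports the shared definition from train.py, drops the now-unused get_env_info import, and comments out the global torch thread caps at the bottom of the file.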
@@ -74,6 +74,7 @@ from train import (
     add_model_arguments,
     get_adjusted_batch_count,
     get_model,
+    get_params,
     load_checkpoint_if_available,
     save_checkpoint,
     set_batch_count,
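The added import pulls get_params from the shared train.py recipe file (its local duplicate is deleted in the third hunk below). The following is a minimal sketch, assuming icefall.utils.AttributeDict behaves like a dict with attribute access, of the kind of object get_params() returns; it is an illustration, not the library's actual class.

```python
# Minimal sketch of a dict with attribute access, in the spirit of
# icefall.utils.AttributeDict (assumed behavior, not the real class).
class AttributeDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(f"No such attribute: {key}")

    def __setattr__(self, key, value):
        self[key] = value


params = AttributeDict({"log_interval": 50})
assert params.log_interval == 50   # attribute-style read
params.batch_idx_train = 0         # attribute-style write
assert params["batch_idx_train"] == 0
```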
@@ -88,7 +89,6 @@ from icefall.checkpoint import (
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import (
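The get_env_info import is dropped here, presumably because the deleted local get_params() (next hunk) was its only caller in this file; the shared get_params() in train.py gathers the environment info itself. For reference, a hedged example of the real icefall.env.get_env_info helper (that it returns a dict is per its use above; the exact keys are an assumption):

```python
from icefall.env import get_env_info

info = get_env_info()       # a dict describing the runtime environment
print(sorted(info.keys()))  # e.g. k2/torch/icefall version fields; exact keys may vary
```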
@@ -320,72 +320,6 @@ def get_parser():
     return parser


-def get_params() -> AttributeDict:
-    """Return a dict containing training parameters.
-
-    All training related parameters that are not passed from the commandline
-    are saved in the variable `params`.
-
-    Commandline options are merged into `params` after they are parsed, so
-    you can also access them via `params`.
-
-    Explanation of options saved in `params`:
-
-        - best_train_loss: Best training loss so far. It is used to select
-                           the model that has the lowest training loss. It is
-                           updated during the training.
-
-        - best_valid_loss: Best validation loss so far. It is used to select
-                           the model that has the lowest validation loss. It is
-                           updated during the training.
-
-        - best_train_epoch: It is the epoch that has the best training loss.
-
-        - best_valid_epoch: It is the epoch that has the best validation loss.
-
-        - batch_idx_train: Used to writing statistics to tensorboard. It
-                           contains number of batches trained so far across
-                           epochs.
-
-        - log_interval: Print training loss if batch_idx % log_interval` is 0
-
-        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
-
-        - valid_interval: Run validation if batch_idx % valid_interval is 0
-
-        - feature_dim: The model input dim. It has to match the one used
-                       in computing features.
-
-        - subsampling_factor: The subsampling factor for the model.
-
-        - encoder_dim: Hidden dim for multi-head attention model.
-
-        - num_decoder_layers: Number of decoder layer of transformer decoder.
-
-        - warm_step: The warmup period that dictates the decay of the
-              scale on "simple" (un-pruned) loss.
-    """
-    params = AttributeDict(
-        {
-            "best_train_loss": float("inf"),
-            "best_valid_loss": float("inf"),
-            "best_train_epoch": -1,
-            "best_valid_epoch": -1,
-            "batch_idx_train": 0,
-            "log_interval": 50,
-            "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
-            # parameters for zipformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,  # not passed in, this is fixed.
-            "warm_step": 2000,
-            "env_info": get_env_info(),
-        }
-    )
-
-    return params
-
-
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
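The deletion above removes train_char.py's private copy of get_params(); together with the import added in the first hunk, the script now uses the single shared definition from train.py. A hedged sketch of the resulting usage pattern follows; the --valid-interval flag is purely illustrative, and importing train assumes the recipe directory is on the path, as in icefall's egs layout.

```python
import argparse

from train import get_params  # the shared definition this commit switches to

parser = argparse.ArgumentParser()
parser.add_argument("--valid-interval", type=int, default=3000)
args = parser.parse_args([])

params = get_params()
params.update(vars(args))  # command-line options merged into params,
                           # as the (deleted) docstring describes
assert params.valid_interval == 3000  # attribute access via AttributeDict
```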
@@ -1017,8 +951,8 @@ def main():
         run(rank=0, world_size=1, args=args)


-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
+# torch.set_num_threads(1)
+# torch.set_num_interop_threads(1)

 if __name__ == "__main__":
     main()
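The final hunk comments out the module-level thread caps. torch.set_num_threads limits intra-op CPU parallelism (e.g. a single matmul) and torch.set_num_interop_threads limits inter-op parallelism; per the PyTorch docs, the latter may only be called before any inter-op parallel work starts. A small sketch of inspecting the defaults that apply once the caps are removed:

```python
import torch

# With the caps commented out, PyTorch picks its own defaults
# (typically based on the number of physical cores).
print("intra-op threads:", torch.get_num_threads())
print("inter-op threads:", torch.get_num_interop_threads())

# If a cap is still wanted, it must be set early, before any
# parallel work runs, or set_num_interop_threads raises RuntimeError.
```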