diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py index 5089c84c9..fb0fc4c72 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py @@ -671,7 +671,9 @@ class Cain(Optimizer): # see for each dim in turn whether we want to perform any changes in co-ordinates, # or store any stats. size = grad.shape[dim] - if size <= 3 or size % 2 == 1 or size >= 2048 or size == numel: + if size <= 3 or size % 2 == 1 or size == 500 or size >= 2048 or size == numel: + # 500: exclude embedding dim, will later find a better way to do this. + # we don't do any such co-ordinate changes in dims with sizes # that are too small (no point) or large (too slow), or that are # assumed convolutional (because they are odd). We can revisit diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index c5391043e..26c2fd4c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -274,6 +274,19 @@ def get_parser(): """, ) + parser.add_argument( + "--average-decay", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + parser.add_argument( "--use-fp16", type=str2bool, @@ -736,6 +749,7 @@ def train_one_epoch( params=params, model_cur=model, model_avg=model_avg, + ) if ( @@ -859,7 +873,7 @@ def run(rank, world_size, args): model_avg: Optional[nn.Module] = None if rank == 0: # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) + model_avg = copy.deepcopy(model).to(torch.float64) assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available( diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 170586455..4e02dd382 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -86,7 +86,7 @@ def save_checkpoint( } if model_avg is not None: - checkpoint["model_avg"] = model_avg.state_dict() + checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() if params: for k, v in params.items():