Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-08 16:44:20 +00:00
Exclude size=500 dim from projection; try to use double for model average

Commit 8e454bcf9e, parent 9ef11e64ba.
```diff
@@ -671,7 +671,9 @@ class Cain(Optimizer):
             # see for each dim in turn whether we want to perform any changes in co-ordinates,
             # or store any stats.
             size = grad.shape[dim]
-            if size <= 3 or size % 2 == 1 or size >= 2048 or size == numel:
+            if size <= 3 or size % 2 == 1 or size == 500 or size >= 2048 or size == numel:
+                # 500: exclude embedding dim, will later find a better way to do this.
+
                 # we don't do any such co-ordinate changes in dims with sizes
                 # that are too small (no point) or large (too slow), or that are
                 # assumed convolutional (because they are odd). We can revisit
```
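For readability, the changed test can be restated as a standalone predicate. This is a sketch, not code from the repository: the helper name `should_skip_dim` is invented, and `grad`, `dim`, and `numel` stand in for locals of the surrounding per-dim loop.

```python
import torch

def should_skip_dim(grad: torch.Tensor, dim: int, numel: int) -> bool:
    """Restates the condition changed in the hunk above (helper name invented)."""
    size = grad.shape[dim]
    return (
        size <= 3          # too small: no point changing co-ordinates
        or size % 2 == 1   # odd sizes are assumed convolutional
        or size == 500     # new in this commit: exclude the embedding dim
        or size >= 2048    # too large: too slow
        or size == numel   # the tensor is effectively one-dimensional
    )

g = torch.zeros(500, 512)
assert should_skip_dim(g, 0, g.numel())        # 500 is now excluded
assert not should_skip_dim(g, 1, g.numel())    # 512 still gets the projection
```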
```diff
@@ -274,6 +274,19 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--average-decay",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
```
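The help text above states the running-average update in closed form. Below is a minimal sketch of that update applied to a single tensor; the function name `update_average` is hypothetical, and the in-place `mul_`/`add_` pair is just one way to compute the stated formula.

```python
import torch

def update_average(param: torch.Tensor,
                   param_avg: torch.Tensor,
                   batch_idx_train: int,
                   average_period: int) -> None:
    # model_avg = model * (average_period / batch_idx_train)
    #     + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
    weight = average_period / batch_idx_train
    param_avg.mul_(1.0 - weight).add_(param, alpha=weight)

# e.g. after 400 batches with average_period=100, the current weights
# contribute 100/400 = 0.25 of the new average:
p = torch.randn(3)
p_avg = torch.randn(3, dtype=torch.float64)
update_average(p.to(torch.float64), p_avg, batch_idx_train=400, average_period=100)
```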
```diff
@@ -736,6 +749,7 @@ def train_one_epoch(
                 params=params,
                 model_cur=model,
                 model_avg=model_avg,
+
             )
 
         if (
```
```diff
@@ -859,7 +873,7 @@ def run(rank, world_size, args):
     model_avg: Optional[nn.Module] = None
     if rank == 0:
         # model_avg is only used with rank 0
-        model_avg = copy.deepcopy(model)
+        model_avg = copy.deepcopy(model).to(torch.float64)
 
     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
```
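The point of the float64 copy is numerical: `model_avg` accumulates many small weighted updates, and keeping it in double precision limits rounding drift. A minimal sketch of the pattern, with `nn.Linear` standing in for the real model:

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for the real model, which trains in float32
# Deep-copy first: nn.Module.to(dtype) casts parameters in place and returns
# self, so converting a copy leaves the training model untouched.
model_avg = copy.deepcopy(model).to(torch.float64)
assert next(model.parameters()).dtype == torch.float32
assert next(model_avg.parameters()).dtype == torch.float64
```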
```diff
@@ -86,7 +86,7 @@ def save_checkpoint(
     }
 
     if model_avg is not None:
-        checkpoint["model_avg"] = model_avg.state_dict()
+        checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
 
     if params:
         for k, v in params.items():
```
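One caveat on this last hunk: `nn.Module.to()` casts a module's parameters in place and returns `self`, so `model_avg.to(torch.float32)` also leaves the live average in float32 after the first checkpoint is saved, undoing the float64 change above. A sketch of a variant that converts a copy instead; this is an assumption about intent, not code from the commit:

```python
import copy
import torch
import torch.nn as nn

model_avg = nn.Linear(4, 4).to(torch.float64)  # stand-in for the averaged model
checkpoint = {}

# Convert a deep copy, so the in-memory average stays in double precision:
checkpoint["model_avg"] = copy.deepcopy(model_avg).to(torch.float32).state_dict()
assert next(model_avg.parameters()).dtype == torch.float64
```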