mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
Merge 01be91217b7205a234af4f70879af75fdf9019b1 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in: commit 02aaa47d99
icefall/bmuf.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional

import torch
import torch.optim


class BmufOptimizer(object):
    """This class implements

    Scalable training of deep learning machines by incremental block training
    with intra-block parallel optimization and blockwise model-update filtering
    (https://ieeexplore.ieee.org/document/7472805)

    using the following implementations as a reference:

    - https://github.com/pytorch/fairseq/blob/main/fairseq/optim/bmuf.py
    - https://github.com/tencent-ailab/pika/blob/main/trainer/bmuf.py
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        block_momentum: float,
        sync_iter: int,
        block_lr: float = 1.0,
    ):
        """
        Args:
          optimizer:
            The underlying optimizer.
          block_momentum:
            The block momentum in the paper.
            A reasonable value is (1 - 1./world_size).
          sync_iter:
            Do block synchronization once every `sync_iter` calls to `step()`.
          block_lr:
            The block learning rate in the paper.
        """
        assert isinstance(optimizer, torch.optim.Optimizer)
        assert 0 <= block_momentum < 1
        assert block_lr > 0

        self._optimizer = optimizer
        self.block_momentum = block_momentum
        self.sync_iter = sync_iter
        self.block_lr = block_lr

        # Incremented whenever `step()` is called.
        # When num_updates % sync_iter == 0, we do block synchronization.
        self.num_updates = 0

    def state_dict(self) -> Dict[str, Any]:
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._optimizer.load_state_dict(state_dict)

    @torch.no_grad()
    def step(self, closure=None) -> None:
        self._optimizer.step(closure=closure)
        # Increment the number of updates.
        # If num_updates % sync_iter == 0, invoke `_block_sync()`.
        self.num_updates += 1
        if self.num_updates % self.sync_iter == 0:
            self._block_sync()

    def zero_grad(self) -> None:
        self._optimizer.zero_grad()

    def _block_sync(self) -> None:
        # (1) Compute the gradient of each parameter:
        #     grad = prev_parameter - current_averaged_parameter
        # (2) Compute the average gradients across all nodes
        # (3) Compute the smoothed grad:
        #     smoothed_grad = block_momentum * prev_smoothed_grad
        #                     + block_lr * grad
        # (4) Update the parameter:
        #     parameter = prev_parameter - smoothed_grad
        # TODO: Support Nesterov momentum when updating the parameter
        #
        # Note: During communication, we can concatenate all parameters
        # into a single vector and send/recv this vector to reduce
        # the communication overhead.
        #
        # (5) Update internal buffers of `_optimizer` if there are any.
        #     For example, for the Adam optimizer, we can average its
        #     buffers across nodes.
        pass
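
For illustration, here is a minimal sketch of what the `_block_sync()` body could look like, following the numbered comments above. It is an assumption, not part of the patch: it presumes two extra buffers that the stub does not yet create, `self._prev_params` (copies of all parameters taken at the previous synchronization) and `self._smoothed_grads` (zero tensors of the same shapes), both of which would have to be set up in `__init__`, and it presumes `torch.distributed` has already been initialized. Step (5), averaging the optimizer's internal buffers, is left out.

import torch
import torch.distributed as dist


def _block_sync_sketch(self) -> None:
    # Hypothetical implementation; `self._prev_params` and
    # `self._smoothed_grads` are assumed buffers that the stub
    # above does not create yet.
    world_size = dist.get_world_size()
    params = [
        p for group in self._optimizer.param_groups for p in group["params"]
    ]
    for p, prev_p, smoothed in zip(
        params, self._prev_params, self._smoothed_grads
    ):
        # (2) Average the current parameter across all nodes.
        averaged = p.detach().clone()
        dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
        averaged /= world_size

        # (1) grad = prev_parameter - current_averaged_parameter
        grad = prev_p - averaged

        # (3) smoothed_grad = block_momentum * prev_smoothed_grad
        #                     + block_lr * grad
        smoothed.mul_(self.block_momentum).add_(grad, alpha=self.block_lr)

        # (4) parameter = prev_parameter - smoothed_grad
        p.detach().copy_(prev_p - smoothed)

        # Remember the updated parameter for the next synchronization.
        prev_p.copy_(p.detach())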
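
A hypothetical usage sketch follows, showing how the wrapper might be dropped into a distributed training loop. The model, dataloader, and loss function are placeholders, and the process group is assumed to have been initialized already (e.g. via torch.distributed.init_process_group); none of this comes from the patch itself.

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


def train_one_epoch(model: nn.Module, dataloader) -> None:
    world_size = dist.get_world_size()

    base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    optimizer = BmufOptimizer(
        optimizer=base_optimizer,
        # The docstring suggests roughly (1 - 1/world_size),
        # e.g. 0.75 when training on 4 nodes.
        block_momentum=1.0 - 1.0 / world_size,
        # Synchronize the block every 50 local updates.
        sync_iter=50,
        block_lr=1.0,
    )

    for features, targets in dataloader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(features), targets)
        loss.backward()
        # Runs the local Adam step; every `sync_iter`-th call would also
        # trigger `_block_sync()` across nodes once it is implemented.
        optimizer.step()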