Merge 01be91217b7205a234af4f70879af75fdf9019b1 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

Fangjun Kuang 2025-08-17 06:12:33 +00:00 committed by GitHub
commit 02aaa47d99

icefall/bmuf.py Normal file

@@ -0,0 +1,101 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict

import torch
import torch.optim


class BmufOptimizer(object):
"""This class implements
Scalable training of deep learning machines by incremental block training
with intra-block parallel optimization and blockwise model-update filtering
(https://ieeexplore.ieee.org/document/7472805)
    using the following implementations as references:
- https://github.com/pytorch/fairseq/blob/main/fairseq/optim/bmuf.py
- https://github.com/tencent-ailab/pika/blob/main/trainer/bmuf.py
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
block_momentum: float,
sync_iter: int,
block_lr: float = 1.0,
):
"""
Args:
optimizer:
The underlying optimizer.
block_momentum:
The block momentum in the paper.
A reasonable value is (1 - 1./world_size).
sync_iter:
            Do block synchronization once every `sync_iter` iterations.
block_lr:
The block learning rate in the paper.
"""
        assert isinstance(optimizer, torch.optim.Optimizer)
assert 0 <= block_momentum < 1
assert block_lr > 0
self._optimizer = optimizer
self.block_momentum = block_momentum
self.sync_iter = sync_iter
self.block_lr = block_lr
        # `num_updates` is incremented whenever `step()` is called.
        # When num_updates % sync_iter == 0, we do block synchronization.
self.num_updates = 0

    def state_dict(self) -> Dict[str, Any]:
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._optimizer.load_state_dict(state_dict)

    @torch.no_grad()
    def step(self, closure=None) -> None:
        self._optimizer.step(closure=closure)
        # Increment the number of updates.
        self.num_updates += 1
        # If num_updates % sync_iter == 0, invoke `_block_sync()`.
        if self.num_updates % self.sync_iter == 0:
            self._block_sync()

    def zero_grad(self) -> None:
        self._optimizer.zero_grad()

def _block_sync(self) -> None:
        # (1) Compute the gradient of each parameter:
# grad = prev_parameter - current_averaged_parameter
# (2) Compute the average gradients across all nodes
# (3) Compute smoothed grad
# smoothed_grad = block_momentum * prev_smoothed_grad +
# block_lr * grad
# (4) Update parameter
# parameter = prev_parameter - smoothed_grad
# TODO: Support Nesterov momentum when updating parameter
#
        # Note: During communication, we can concatenate all parameters
        # into a single vector and send/recv that single vector to reduce
        # the communication overhead
#
# (5) Update internal buffers of `_optimizer` if there are any.
# For example, for the adam optimizer, we can average its buffers
# across nodes.
pass
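
The comment block above describes the BMUF update but leaves `_block_sync()` as a TODO. Below is a minimal sketch of what the method could look like, following steps (1)-(4); it is an assumption, not part of the committed file. It presumes that `torch.distributed` is already initialized and that `__init__` additionally creates two hypothetical buffers, `self._prev_params` (a detached copy of every parameter taken at the previous sync point) and `self._smoothed_grads` (zero-initialized tensors of the same shapes). Nesterov momentum, the single-vector communication trick, and step (5) (averaging the inner optimizer's buffers) are left out.

import torch.distributed as dist


def _block_sync(self) -> None:
    """Hypothetical sketch of the block synchronization; not the committed code."""
    world_size = dist.get_world_size()
    params = [
        p for group in self._optimizer.param_groups for p in group["params"]
    ]
    for p, prev_p, smoothed in zip(
        params, self._prev_params, self._smoothed_grads
    ):
        # Average the current parameter across all nodes.
        dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
        p.data.div_(world_size)

        # Block-level "gradient": how far the averaged parameter has moved
        # away from the parameter saved at the previous sync point.
        grad = prev_p - p.data

        # Blockwise model-update filtering (smoothed gradient).
        smoothed.mul_(self.block_momentum).add_(grad, alpha=self.block_lr)

        # Update the parameter, starting from the previous sync point.
        p.data.copy_(prev_p - smoothed)

        # Remember the new sync point for the next block.
        prev_p.copy_(p.data)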
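
For context, here is a hedged usage sketch showing how the class is meant to wrap an ordinary optimizer in a multi-node training script. The model, the dummy data, and the hyper-parameter values are placeholders, and the process group is assumed to have been initialized by the launcher (e.g. torchrun):

import torch
import torch.distributed as dist
import torch.nn as nn

from icefall.bmuf import BmufOptimizer

# Placeholder model and data; dist.init_process_group() is assumed to have
# been called already, one process per node.
model = nn.Linear(80, 500)
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = BmufOptimizer(
    optimizer=base_optimizer,
    block_momentum=1 - 1.0 / dist.get_world_size(),  # value suggested in the docstring
    sync_iter=50,  # placeholder: sync once every 50 local steps
    block_lr=1.0,
)

for i in range(1000):
    x = torch.randn(16, 80)  # dummy batch
    loss = model(x).pow(2).mean()  # dummy loss
    optimizer.zero_grad()
    loss.backward()
    # Local update; once every `sync_iter` calls this would also trigger
    # the BMUF block synchronization across nodes.
    optimizer.step()

The division of labor mirrors the referenced fairseq and pika wrappers: the inner optimizer handles every local update, while BmufOptimizer only intervenes at block boundaries.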