Broadcast transformation

Daniel Povey 2022-05-17 17:22:43 +08:00
parent ceb4eb4b85
commit a46f74feb4


@@ -22,6 +22,7 @@ from typing import Optional, Tuple
 import logging
 import torch
 from torch import Tensor, nn
+import torch.distributed as dist

 # some utilities for diagonalizing models (rotating their parameter matrices
 # so that large and small parameter values are separated as much as possible).
@@ -261,4 +262,14 @@ class OrthogonalTransformation(nn.Module):
     def get_transformation_out(self) -> Tensor:
         # see also get_transformation() above for notes on this.
         cov = 0.5 * (self.feats_cov + self.feats_cov.t())  # make sure symmetric
-        return get_transformation(cov)
+        t = get_transformation(cov)
+        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
+            # Make sure all processes in the process group share the same version of `t`.
+            # This would usually be the case, but if we modified self.feats_cov on this
+            # batch, it won't be the same on all processes, because DDP synchronizes
+            # buffers at the beginning, not the end, of forward().
+            logging.info("Broadcasting transformation")
+            dist.broadcast(t, src=0)
+        return t
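
As a self-contained illustration of the pattern this commit adds, the sketch below derives a tensor independently on each rank and then broadcasts rank 0's copy, so that every process ends up applying an identical transformation. It is hypothetical code, not from this repository: synced_transformation is an invented name, and an eigendecomposition stands in for get_transformation() just to keep the example runnable.

import torch
import torch.distributed as dist

def synced_transformation(cov: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for get_transformation(): any deterministic
    # function of `cov` that can diverge across ranks if `cov` does.
    t = torch.linalg.eigh(cov).eigenvectors
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # broadcast() copies rank 0's tensor into `t` in-place on every
        # other rank, so all processes leave this call with the same values.
        dist.broadcast(t, src=0)
    return t

Broadcasting from a single rank keeps the transformation bit-identical everywhere, which an averaging all-reduce would not guarantee; that presumably matters here, since each process would otherwise rotate its parameter matrices slightly differently and the models would drift apart.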