diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py
index 41122dc7d..ddc52fc0e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py
@@ -22,6 +22,7 @@ from typing import Optional, Tuple
 import logging
 import torch
 from torch import Tensor, nn
+import torch.distributed as dist
 
 # some utilities for diagnalizing models (rotating their parameters matrices
 # so that large and small parameter values are separated as much as possible).
@@ -261,4 +262,12 @@ class OrthogonalTransformation(nn.Module):
     def get_transformation_out(self) -> Tensor:
         # see also get_transformation() above for notes on this.
         cov = 0.5 * (self.feats_cov + self.feats_cov.t())  # make sure symmetric
-        return get_transformation(cov)
+        t = get_transformation(cov)
+        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
+            # Make sure all processes in the process group share the same version
+            # of `t`.  This would usually be the case anyway, but if we modified
+            # self.feats_cov on this batch it won't be, because DDP synchronizes
+            # buffers at the beginning, not the end, of the forward().
+            logging.info("Broadcasting transformation")
+            dist.broadcast(t, src=0)
+        return t
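
For background on the inconsistency the new broadcast guards against, here is a minimal sketch (not part of the patch): DistributedDataParallel broadcasts module buffers from rank 0 at the start of forward(), so a buffer like self.feats_cov that is updated during forward() can diverge across ranks until the next step, and a tensor derived from it must be synchronized explicitly. The helper compute_transform below is a hypothetical stand-in for get_transformation().

import torch
import torch.distributed as dist

def compute_transform(cov: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for get_transformation(): some orthogonal
    # matrix computed from a (possibly rank-divergent) covariance buffer.
    return torch.linalg.eigh(cov).eigenvectors

def synced_transform(cov: torch.Tensor) -> torch.Tensor:
    t = compute_transform(cov)
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # In-place broadcast: after this call every rank holds rank 0's
        # copy of `t`, so all replicas apply the same transformation.
        dist.broadcast(t, src=0)
    return t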