Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 16:14:17 +00:00)
Broadcast transformation

parent ceb4eb4b85
commit a46f74feb4
@@ -22,6 +22,7 @@ from typing import Optional, Tuple
 import logging
 import torch
 from torch import Tensor, nn
+import torch.distributed as dist

 # some utilities for diagonalizing models (rotating their parameter matrices
 # so that large and small parameter values are separated as much as possible).
@@ -261,4 +262,14 @@ class OrthogonalTransformation(nn.Module):
     def get_transformation_out(self) -> Tensor:
         # see also get_transformation() above for notes on this.
         cov = 0.5 * (self.feats_cov + self.feats_cov.t())  # make sure symmetric
-        return get_transformation(cov)
+        t = get_transformation(cov)
+        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
+            # make sure all processes in the process group share the same version of `t`.
+            # this would usually be the case, but if on this batch we modified self.feats_cov,
+            # it won't be the same among all processes because DDP synchronizes buffers at the
+            # beginning, not the end, of the forward().
+            logging.info("Broadcasting transformation")
+            dist.broadcast(t, src=0)
+        return t
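For context, below is a minimal, self-contained sketch of the pattern the new code relies on; it is not part of the commit. Each rank builds a slightly different tensor t, standing in for the transformation derived from self.feats_cov, and dist.broadcast(t, src=0) makes every rank adopt rank 0's copy, which is what keeps the transformation identical across DDP workers. The run_worker helper, the gloo backend, and the fixed master address/port are illustrative assumptions, not icefall code.

# Hypothetical demo, not icefall code: broadcast a tensor so all ranks agree on it.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_worker(rank: int, world_size: int) -> None:
    # Minimal process-group setup; the address and port are arbitrary local values.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank has a slightly different tensor, mimicking a buffer that diverged
    # during forward() (DDP syncs buffers before, not after, forward()).
    t = torch.eye(3) + 0.01 * rank

    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # After this call, every rank holds rank 0's version of `t`.
        dist.broadcast(t, src=0)

    print(f"rank {rank}: t[0, 0] = {t[0, 0].item():.2f}")  # 1.00 on every rank
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size)

Running this with two processes prints the same value on both ranks; without the broadcast, rank 1 would print 1.01 instead of 1.00.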
|