mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Broadcast transformation
This commit is contained in:
parent
ceb4eb4b85
commit
a46f74feb4
@ -22,6 +22,7 @@ from typing import Optional, Tuple
|
||||
import logging
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.distributed as dist
|
||||
|
||||
# some utilities for diagonalizing models (rotating their parameter matrices
|
||||
# so that large and small parameter values are separated as much as possible).
|
||||
@ -261,4 +262,14 @@ class OrthogonalTransformation(nn.Module):
|
||||
def get_transformation_out(self) -> Tensor:
    """Return the transformation computed from the symmetrized `self.feats_cov`.

    The covariance buffer is averaged with its transpose and handed to
    `get_transformation()`.  When torch.distributed is initialized with more
    than one process, the result is broadcast from rank 0 so that every
    process ends up with the same transformation.

    Returns:
      The Tensor produced by `get_transformation()`; under multi-process
      distributed training this is rank 0's version (`t` is overwritten
      in place on the other ranks by the broadcast).
    """
    # see also get_transformation() above for notes on this.
    cov = 0.5 * (self.feats_cov + self.feats_cov.t())  # make sure symmetric
    t = get_transformation(cov)

    # A world size of 1 has nothing to synchronize, so only broadcast when
    # there is more than one process.
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # make sure all processes in the process group share the same version of `t`.
        # this would usually be the case, but if on this batch we modified self.feats_cov,
        # it won't be the same among all processes because DDP synchronizes buffers at the
        # beginning, not the end, of the forward().
        logging.info("Broadcasting transformation")
        # The source rank must be given explicitly; rank 0 is treated as the
        # authoritative copy.
        dist.broadcast(t, src=0)
    return t
|
||||
|
Loading…
x
Reference in New Issue
Block a user