Fix bug in diagnostics RE gpu

This commit is contained in:
Daniel Povey 2022-11-30 16:02:18 +08:00
parent b7cad258bb
commit b79a794706

View File

@ -401,7 +401,7 @@ class ScalarDiagnostic(object):
sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32) sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32)
sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32) sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32)
counts_cumsum = self.counts.cumsum(dim=0) counts_cumsum = counts.cumsum(dim=0)
counts_tot = counts_cumsum[-1] counts_tot = counts_cumsum[-1]
# subdivide the distribution up into `num_bins` intervals for analysis, for greater # subdivide the distribution up into `num_bins` intervals for analysis, for greater