mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix bug in diagnostics RE gpu
This commit is contained in:
parent
b7cad258bb
commit
b79a794706
@ -401,7 +401,7 @@ class ScalarDiagnostic(object):
|
|||||||
sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32)
|
sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32)
|
||||||
sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32)
|
sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32)
|
||||||
|
|
||||||
counts_cumsum = self.counts.cumsum(dim=0)
|
counts_cumsum = counts.cumsum(dim=0)
|
||||||
counts_tot = counts_cumsum[-1]
|
counts_tot = counts_cumsum[-1]
|
||||||
|
|
||||||
# subdivide the distribution up into `num_bins` intervals for analysis, for greater
|
# subdivide the distribution up into `num_bins` intervals for analysis, for greater
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user