From d04271219c30cda2c517781b17243bd1fa75e6d6 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 29 Sep 2021 16:01:53 +0800 Subject: [PATCH] Update utils.py --- icefall/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index bc262ec28..b85b6bf7b 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -472,16 +472,13 @@ class LossRecord(collections.defaultdict): all processes get the total. """ keys = sorted(self.keys()) - s = torch.tensor([ float(self[k]) for k in keys ], + s = torch.tensor([float(self[k]) for k in keys], device=device) dist.all_reduce(s, op=dist.ReduceOp.SUM) for k, v in zip(keys, s.cpu().tolist()): self[k] = v - def write_summary(self, - tb_writer: SummaryWriter, - prefix: str, - batch_idx: int) -> None: + def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None: """Add logging information to a TensorBoard writer. Args: