diff --git a/icefall/utils.py b/icefall/utils.py index bc262ec28..b85b6bf7b 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -472,16 +472,13 @@ class LossRecord(collections.defaultdict): all processes get the total. """ keys = sorted(self.keys()) - s = torch.tensor([ float(self[k]) for k in keys ], + s = torch.tensor([float(self[k]) for k in keys], device=device) dist.all_reduce(s, op=dist.ReduceOp.SUM) for k, v in zip(keys, s.cpu().tolist()): self[k] = v - def write_summary(self, - tb_writer: SummaryWriter, - prefix: str, - batch_idx: int) -> None: + def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None: """Add logging information to a TensorBoard writer. Args: