#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import torch
import torch.distributed as dist

# We might need to move this file to icefall/utils.py in the future


# Copied from https://github.com/microsoft/Swin-Transformer/blob/main/utils.py
def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return the mean of ``tensor`` across all distributed processes.

    The input is cloned first, so the caller's tensor is left untouched.
    The clone is summed over every process with ``dist.all_reduce`` and
    then divided in place by ``dist.get_world_size()``.

    Args:
      tensor: The local tensor to average.

    Returns:
      A new tensor holding the cross-process average.
    """
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def fix_random_seed(random_seed: int, rank: int) -> None:
    """Seed the torch, torch.cuda, numpy and ``random`` generators.

    The effective seed is ``random_seed + rank``, so each rank gets a
    different (but reproducible) random stream.

    Args:
      random_seed: Base seed.
      rank: Offset added to the base seed.
    """
    seed = random_seed + rank
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/metrics.py
class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Clear all statistics back to zero."""
        self.val = 0  # most recent value passed to update()
        self.avg = 0  # running weighted average, i.e. sum / count
        self.sum = 0  # weighted sum of all values seen so far
        self.count = 0  # total weight (e.g. number of samples) seen so far

    def update(self, val, n=1) -> None:
        """Record a new value.

        Args:
          val: The value to record.
          n: Weight of ``val`` (e.g. the number of samples it summarizes).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. the first update comes
        # with n=0 for an empty batch), which would otherwise raise
        # ZeroDivisionError; the average then stays at its reset value.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self) -> str:
        return f"{self.val:.4f} (avg: {self.avg:.4f})"


# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/metrics.py
def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
      output: Score tensor; the top-k is taken along dim 1
        (presumably of shape (batch, num_classes) -- confirm with callers).
      target: Tensor of target indexes; its size along dim 0 is used as
        the batch size.
      topk: The values of k to evaluate. Each k is clamped to the size of
        dim 1 of ``output``.

    Returns:
      A list with one tensor per entry of ``topk``, each holding the
      accuracy as a percentage (0 to 100).
    """
    # Never ask topk() for more entries than dim 1 of `output` provides.
    maxk = min(max(topk), output.size()[1])
    batch_size = target.size(0)
    # Indexes of the maxk largest scores per row, sorted by descending score.
    _, pred = output.topk(maxk, 1, True, True)
    # Transpose so that row j holds the (j+1)-th best guess of every sample.
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    # A sample counts as correct for k if the target is in its first k guesses.
    return [
        correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size
        for k in topk
    ]