You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

16 lines
293 B

import torch
def top1_accuracy(qf, ql, gf, gl):
    """Return the fraction of queries whose most similar gallery item shares its label.

    Similarity is the raw dot product ``qf @ gf.T``. No normalization is applied
    here, so this ranks by cosine similarity only if the caller already
    L2-normalized the features (NOTE(review): confirm with the feature extractor).

    Args:
        qf: Query feature matrix, presumably (num_query, dim).
        ql: Query labels, one per row of ``qf``.
        gf: Gallery feature matrix, presumably (num_gallery, dim).
        gl: Gallery labels, one per row of ``gf``.

    Returns:
        Top-1 accuracy as a Python float in [0, 1].

    Raises:
        ValueError: If there are no queries or the gallery is empty.
    """
    num_queries = ql.size(0)
    if num_queries == 0:
        raise ValueError("no query samples to evaluate")
    if gf.size(0) == 0:
        raise ValueError("gallery is empty")

    scores = qf.mm(gf.t())
    # argmax picks the single best gallery index per query. A topk(5) here would
    # compute unused neighbours and fail outright for galleries with < 5 items.
    best = scores.argmax(dim=1)
    correct = gl[best].eq(ql).sum().item()
    return correct / num_queries


def main(path="features.pth"):
    """Load the saved feature dict from *path* and print top-1 accuracy.

    The file must contain the keys "qf", "ql", "gf" and "gl".
    """
    # torch.load unpickles the file; only load feature files from trusted sources.
    features = torch.load(path)
    acc = top1_accuracy(features["qf"], features["ql"], features["gf"], features["gl"])
    print("Acc top1:{:.3f}".format(acc))


if __name__ == "__main__":
    main()