-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
29 lines (25 loc) · 934 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import random
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score
from pdb import set_trace
def set_random_seed(seed):
    """Seed the PyTorch, Python ``random`` and NumPy RNGs and pin cuDNN.

    The three generators are seeded in that order with the same ``seed``.
    cuDNN is then forced into deterministic mode with autotuning
    (``benchmark``) disabled, so that repeated runs pick the same kernels.

    Args:
        seed: Value passed unchanged to each seeding function.
    """
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def _to_flat_numpy(value):
    """Return *value* as a flat 1-D NumPy array.

    Torch tensors are detached from the autograd graph and moved to the CPU
    first, because ``Tensor.numpy()`` raises for tensors that require grad or
    live on a GPU. Any other value is converted with ``np.asarray``.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    return np.asarray(value).reshape(-1)


def aggregate_metrics(log):
    """Aggregate a list of per-record metric dicts into one value per metric.

    Args:
        log: Sequence of dicts, presumably one per evaluation batch (confirm
            with the caller). The keys to aggregate are taken from ``log[0]``,
            so every record is expected to carry the same keys. Values are
            torch tensors (or array-likes) of any shape; each is flattened
            before concatenation.

    Returns:
        Dict mapping each key to its aggregate:

        * ``'auc'``: each record's value is a dict with ``'logits'`` and
          ``'scores'`` entries. Both are concatenated across records and
          passed to ``roc_auc_score(scores, logits)``.
        * ``'pred'``: the sum of all concatenated values.
        * any other key: the mean of all concatenated values.

        An empty ``log`` yields an empty dict.
    """
    if not log:
        # Nothing to aggregate; avoid IndexError on log[0].
        return {}

    results = {}
    for k in log[0]:
        if k == 'auc':
            logits = np.concatenate([_to_flat_numpy(x[k]['logits']) for x in log])
            scores = np.concatenate([_to_flat_numpy(x[k]['scores']) for x in log])
            # roc_auc_score's signature is (y_true, y_score): 'scores' is used
            # as the ground truth and 'logits' as the ranking score.
            # NOTE(review): presumably 'scores' holds the labels — confirm with
            # whoever builds `log`.
            results[k] = roc_auc_score(scores, logits)
        elif k == 'pred':
            results[k] = np.concatenate([_to_flat_numpy(x[k]) for x in log]).sum()
        else:
            results[k] = np.mean(np.concatenate([_to_flat_numpy(x[k]) for x in log]))
    return results