Skip to content

Commit

Permalink
Modified Result.plot_* methods to use Result.raw_* methods internally.
Browse files Browse the repository at this point in the history
  • Loading branch information
mrucker committed Sep 3, 2023
1 parent 408cad7 commit 4927225
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 82 deletions.
119 changes: 40 additions & 79 deletions coba/experiments/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,11 +1013,10 @@ def raw_learners(self,

group = list(group)

if f'x={x}' not in data:
data[f'p={p}'] = list(map(itemgetter(2),group))
data[f'x={x}'] = list(map(itemgetter(1),group))

data[f'l={_l}'] = list(map(itemgetter(3),group))
if 'x' not in data:
data[f'p'] = list(map(itemgetter(2),group))
data[f'x'] = list(map(itemgetter(1),group))
data[_l] = list(map(itemgetter(3),group))

return Table(data)

Expand Down Expand Up @@ -1062,10 +1061,10 @@ def raw_contrast(self,
_x = _x[0] if _x[0] == _x[1] else f"{_x[1]}-{_x[0]}"

for _,(_y1,_y2),_p in group:
data[f'p={p}'].append(_p)
data[f'x={x}'].append(_x)
data[f'l1={og_l[0]}'].append(_y1)
data[f'l2={og_l[1]}'].append(_y2)
data['p'].append(_p)
data['x'].append(_x)
data['l1'].append(_y1)
data['l2'].append(_y2)
else:
#this implementation is considerably faster but only gives correct results under certain conditions
for _x, _group in groupby(plottable._indexed_ys(x,l,eid,lid,x,y=y,span=span),key=itemgetter(0)):
Expand All @@ -1075,10 +1074,10 @@ def raw_contrast(self,
_L2 = [g[1:] for g in _group if g[0] in l2 ]

for _,(_y1,_y2),_p in plottable._pairings(p,_L1,_L2):
data[f'p={p}'].append(_p)
data[f'x={x}'].append(_x)
data[f'l1={og_l[0]}'].append(_y1)
data[f'l2={og_l[1]}'].append(_y2)
data['p'].append(_p)
data['x'].append(_x)
data['l1'].append(_y1)
data['l2'].append(_y2)

if not data:
raise CobaException(f"We were unable to create any pairings to contrast. Make sure l1={og_l[0]} and l2={og_l[1]} is correct.")
Expand Down Expand Up @@ -1139,23 +1138,20 @@ def plot_learners(self,

if isinstance(labels,str): labels = [labels]

plottable = self._plottable(x,y)._finished(x,y,l,p)
n_interactions = len(next(plottable.interactions.groupby(3))[1])
raw_data = self.raw_learners(x,y,l,p,span)

errevery = errevery or max(int(n_interactions*0.05),1) if x == 'index' else 1
errevery = errevery or max(int(raw_data['x'][-1]*0.05),1) if x == 'index' else 1
style = "-" if x == 'index' else "."
err = plottable._confidence(err, errevery)
err = self._confidence(err, errevery)
x_prep = str if x != 'index' else (lambda _x: _x)

lines: List[Points] = []
for _l, group in groupby(plottable._indexed_ys(l,x,y=y,span=span),key=itemgetter(0)):
for _l in raw_data.columns[2:]:
color = self._get_color(colors, len(lines))
label = self._get_label(labels,_l,len(lines))

color = plottable._get_color(colors, len(lines))
label = plottable._get_label(labels,_l,len(lines))
group = map(itemgetter(slice(1,None)),group)
lines.append(Points(style=style,color=color,label=label))

for _xi, (_x, group) in enumerate(groupby(group, key=itemgetter(0))):
for _xi, (_x, group) in enumerate(groupby(zip(*raw_data[['x',_l]]), key=itemgetter(0))):
Y = [g[-1] for g in group]
lines[-1].add(x_prep(_x), *err(Y, _xi))

Expand All @@ -1175,8 +1171,8 @@ def plot_learners(self,

if top_n:
if abs(top_n) > len(lines): top_n = len(lines)*abs(top_n)/top_n
if top_n > 0: lines = [replace(l,color=plottable._get_color(colors,i),label=plottable._get_label(labels,l.label,i)) for i,l in enumerate(lines[:top_n],0 ) ]
if top_n < 0: lines = [replace(l,color=plottable._get_color(colors,i),label=plottable._get_label(labels,l.label,i)) for i,l in enumerate(lines[top_n:],top_n) ]
if top_n > 0: lines = [replace(l,color=self._get_color(colors,i),label=self._get_label(labels,l.label,i)) for i,l in enumerate(lines[:top_n],0 ) ]
if top_n < 0: lines = [replace(l,color=self._get_color(colors,i),label=self._get_label(labels,l.label,i)) for i,l in enumerate(lines[top_n:],top_n) ]

self._plotter.plot(ax, lines, title, xlabel, ylabel, xlim, ylim, xticks, yticks, xrotation, yrotation, out)

Expand Down Expand Up @@ -1241,67 +1237,32 @@ def plot_contrast(self,
xlim = xlim or [None,None]
ylim = ylim or [None,None]

og_l = (l1,l2)
raw_data = self.raw_contrast(l1,l2,x,y,l,p,span)

list_like=(list,tuple)

if isinstance(l,list_like) and not isinstance(l1[0],list_like): l1 = [l1]
if isinstance(l,list_like) and not isinstance(l2[0],list_like): l2 = [l2]
if not isinstance(l,list_like) and not isinstance(l1,list_like): l1 = [l1]
if not isinstance(l,list_like) and not isinstance(l2,list_like): l2 = [l2]
if isinstance(labels,str): labels = [labels]

if any(_l1 in l2 for _l1 in l1):
raise CobaException("A value cannot be in both `l1` and `l2`. Please make a change and run it again.")
if isinstance(l,list_like) and not isinstance(l1[0],list_like): l1 = [l1]
if isinstance(l,list_like) and not isinstance(l2[0],list_like): l2 = [l2]
if not isinstance(l,list_like) and not isinstance(l1 ,list_like): l1 = [l1]
if not isinstance(l,list_like) and not isinstance(l2 ,list_like): l2 = [l2]

contraster = (lambda x,y: y-x) if mode == 'diff' else (lambda x,y: int(y-x>0)) if mode=='prob' else mode
_boundary = 0 if mode == 'diff' else .5

plottable = self._plottable(x,y)
eid = 'environment_id'
lid = 'learner_id'

n_interactions = len(next(plottable.interactions.groupby(3))[1])

errevery = errevery or max(int(n_interactions*0.05),1) if x == 'index' else 1
errevery = errevery or max(int(raw_data['x'][-1]*0.05),1) if x == 'index' else 1
style = "-" if x == 'index' else "."
err = plottable._confidence(err, errevery)

if x != 'index':
#this implementation is considerably slower but always gives the correct results
L1,L2 = [],[]
for _l, group in groupby(plottable._indexed_ys(l,eid,lid,x,y=y,span=span),key=itemgetter(0)):

if _l in l1:
L1.extend(map(itemgetter(slice(1,None)),group))
if _l in l2:
L2.extend(map(itemgetter(slice(1,None)),group))

X_Y_YE = []
for _xi, (_x, group) in enumerate(groupby(sorted(plottable._pairings(p,L1,L2),key=cmp_to_key(comparer)),key=itemgetter(0))):
_x = f"{_x[0]}" if _x[0] == _x[1] else f"{_x[1]}-{_x[0]}"
_Y = [contraster(*pair) for _,pair,_ in group]
if _Y: X_Y_YE.append((_x,) + err(_Y,_xi))

else:
#this implementation is considerably faster but only gives correct results under certain conditions
X_Y_YE = []
for _xi, (_x, _group) in enumerate(groupby(plottable._indexed_ys(x,l,eid,lid,x,y=y,span=span),key=itemgetter(0))):

_group = list(map(itemgetter(slice(1,None)),_group))
_L1 = [g[1:] for g in _group if g[0] in l1]
_L2 = [g[1:] for g in _group if g[0] in l2]
_Y = [contraster(*pair) for _,pair,_ in plottable._pairings(p,_L1,_L2)]
err = self._confidence(err, errevery)

if _Y: X_Y_YE.append((str(_x) if x != 'index' else _x,) + err(_Y,_xi))
X_Y_YE = []
for _xi, (_x, group) in enumerate(groupby(zip(*raw_data[['x','l1','l2']]), key=itemgetter(0))):
_Y = [ contraster(g[1],g[2]) for g in group ]

if not X_Y_YE:
raise CobaException(f"We were unable to create any pairings to contrast. Make sure l1={og_l[0]} and l2={og_l[1]} is correct.")
if _Y: X_Y_YE.append((str(_x) if x!='index' else _x,) + err(_Y,_xi))

if x == 'index':
X,Y,YE = zip(*X_Y_YE)
color = plottable._get_color(colors, 0)
label = plottable._get_label(labels,'l2-l1',0)
color = self._get_color(colors, 0)
label = self._get_label(labels,'l2-l1',0)
label = f"{label}" if legend else None
lines = [Points(X,Y,None,YE, style=style, label=label, color=color)]

Expand All @@ -1320,7 +1281,7 @@ def plot_contrast(self,
X_Y_YE = sorted(X_Y_YE)

X,Y,YE = zip(*X_Y_YE)
color = plottable._get_color(colors, 0)
color = self._get_color(colors, 0)
lines = [Points(X,Y,None,YE, style=style, label=None, color=color)]

else:
Expand All @@ -1341,22 +1302,22 @@ def plot_contrast(self,

if l1_win:
X,Y,YE = zip(*l1_win)
color = plottable._get_color(colors, 0)
label = plottable._get_label(labels,'l1',0)
color = self._get_color(colors, 0)
label = self._get_label(labels,'l1',0)
label = f"{label} ({len(X)})" if legend else None
lines.append(Points(X,Y,None,YE, style=style, label=label, color=color))

if no_win:
X,Y,YE = zip(*no_win)
color = plottable._get_color(colors, 1)
color = self._get_color(colors, 1)
label = 'Tie'
label = f"{label} ({len(X)})" if legend else None
lines.append(Points(X,Y,None,YE, style=style, label=label, color=color))

if l2_win:
X,Y,YE = zip(*l2_win)
color = plottable._get_color(colors, 2)
label = plottable._get_label(labels,'l2',1)
color = self._get_color(colors, 2)
label = self._get_label(labels,'l2',1)
label = f"{label} ({len(X)})" if legend else None
lines.append(Points(X,Y,None,YE, style=style, label=label, color=color))

Expand Down
6 changes: 3 additions & 3 deletions coba/tests/test_experiments_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ def test_raw_learners_all_default(self):
]

table = Result(envs, lrns, vals, ints).raw_learners()
self.assertEqual(('p=environment_id','x=index','l=1. learner_1','l=2. learner_2'), table.columns)
self.assertEqual(('p','x','1. learner_1','2. learner_2'), table.columns)
self.assertEqual([(0,1,1,1),(1,1,1,1),(0,2,1.5,1.5),(1,2,1.5,1.5)], list(table))

def test_raw_contrast_all_default(self):
Expand All @@ -1380,7 +1380,7 @@ def test_raw_contrast_all_default(self):
]

table = Result(envs, lrns, vals, ints).raw_contrast(1,2)
self.assertEqual(('p=environment_id','x=environment_id','l1=1','l2=2'), table.columns)
self.assertEqual(('p','x','l1','l2'), table.columns)
self.assertEqual([(0,0,1.5,1.5)], list(table))

def test_raw_contrast_index(self):
Expand All @@ -1395,7 +1395,7 @@ def test_raw_contrast_index(self):
]

table = Result(envs, lrns, vals, ints).raw_contrast(1,2,x='index')
self.assertEqual(('p=environment_id','x=index','l1=1','l2=2'), table.columns)
self.assertEqual(('p','x','l1','l2'), table.columns)
self.assertEqual([(0,1,1,1),(1,1,1,1),(0,2,1.5,1.5),(1,2,1.5,1.5)], list(table))

def test_raw_contrast_bad_l(self):
Expand Down

0 comments on commit 4927225

Please sign in to comment.