Skip to content

Commit

Permalink
Modified plot_contrast and raw_contrast to show full_name when l='lea…
Browse files Browse the repository at this point in the history
…rner_id' and x!='index'.
  • Loading branch information
mrucker committed Sep 14, 2023
1 parent 3866888 commit 4bec85e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
25 changes: 16 additions & 9 deletions coba/experiments/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,10 @@ def raw_contrast(self,

if x != 'index':
#this implementation is considerably slower but always gives the correct results

l1_label = self._lrn_cache[l1[0]]['full_name'] if l=='learner_id' else 'l1'
l2_label = self._lrn_cache[l2[0]]['full_name'] if l=='learner_id' else 'l2'

L1,L2 = [],[]
for _l, group in groupby(plottable._indexed_ys(l,eid,lid,x,y=y,span=span),key=itemgetter(0)):

Expand All @@ -1063,8 +1067,8 @@ def raw_contrast(self,
for _,(_y1,_y2),_p in group:
data['p'].append(_p)
data['x'].append(_x)
data['l1'].append(_y1)
data['l2'].append(_y2)
data[l1_label].append(_y1)
data[l2_label].append(_y2)
else:
#this implementation is considerably faster but only gives correct results under certain conditions
for _x, _group in groupby(plottable._indexed_ys(x,l,eid,lid,x,y=y,span=span),key=itemgetter(0)):
Expand Down Expand Up @@ -1254,15 +1258,18 @@ def plot_contrast(self,
err = self._confidence(err, errevery)

X_Y_YE = []
for _xi, (_x, group) in enumerate(groupby(zip(*raw_data[['x','l1','l2']]), key=itemgetter(0))):
for _xi, (_x, group) in enumerate(groupby(zip(*raw_data[raw_data.columns[-3:]]), key=itemgetter(0))):
_Y = [ contraster(g[1],g[2]) for g in group ]

if _Y: X_Y_YE.append((str(_x) if x!='index' else _x,) + err(_Y,_xi))

l1_label = raw_data.columns[-2]
l2_label = raw_data.columns[-1]

if x == 'index':
X,Y,YE = zip(*X_Y_YE)
color = self._get_color(colors, 0)
label = self._get_label(labels,'l2-l1',0)
color = self._get_color(colors, 0)
label = self._get_label(labels,f'{l2_label}-{l1_label}',0)
label = f"{label}" if legend else None
lines = [Points(X, Y, None, YE, style=style, label=label, color=color)]

Expand Down Expand Up @@ -1301,8 +1308,8 @@ def plot_contrast(self,
lines = []

X,Y,YE = zip(*l1_win) if l1_win else ((),(),None)
color = self._get_color(colors, 0)
label = self._get_label(labels,'l1',0)
color = self._get_color(colors, 0)
label = self._get_label(labels,l1_label,0)
label = f"{label} ({len(X)})" if legend else None
lines.append(Points(X, Y, None, YE, style=style, label=label, color=color))

Expand All @@ -1313,8 +1320,8 @@ def plot_contrast(self,
lines.append(Points(X, Y, None, YE, style=style, label=label, color=color))

X,Y,YE = zip(*l2_win) if l2_win else ((),(),None)
color = self._get_color(colors, 2)
label = self._get_label(labels,'l2',1)
color = self._get_color(colors, 2)
label = self._get_label(labels,l2_label,1)
label = f"{label} ({len(X)})" if legend else None
lines.append(Points(X, Y, None, YE, style=style, label=label, color=color))

Expand Down
26 changes: 13 additions & 13 deletions coba/tests/test_experiments_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,7 +1380,7 @@ def test_raw_contrast_all_default(self):
]

table = Result(envs, lrns, vals, ints).raw_contrast(1,2)
self.assertEqual(('p','x','l1','l2'), table.columns)
self.assertEqual(('p','x','1. learner_1','2. learner_2'), table.columns)
self.assertEqual([(0,0,1.5,1.5)], list(table))

def test_raw_contrast_index(self):
Expand Down Expand Up @@ -2115,9 +2115,9 @@ def test_plot_contrast_four_environment_all_default(self):
result.plot_contrast(2,1)

expected_lines = [
Points(('2','1'), (-2,-1), None, (0,0), 0 , 1, 'l1 (2)' , '.', 1.),
Points(('2','1'), (-2,-1), None, (0,0), 0 , 1, '2. learner_2 (2)' , '.', 1.),
Points(() , () , None, None , 1 , 1, 'Tie (0)', '.', 1.),
Points(('0','3'), ( 1, 2), None, (0,0), 2 , 1, 'l2 (2)' , '.', 1.),
Points(('0','3'), ( 1, 2), None, (0,0), 2 , 1, '1. learner_1 (2)' , '.', 1.),
Points(('2','3'), ( 0, 0), None, None, "#888", 1, None , '-', .5),
]

Expand Down Expand Up @@ -2146,9 +2146,9 @@ def test_plot_contrast_xlabel_ylabel(self):
result.plot_contrast(2,1,xlabel='x',ylabel='y')

expected_lines = [
Points(('2','1'), (-2,-1), None, (0,0), 0 , 1, 'l1 (2)' , '.', 1.),
Points(('2','1'), (-2,-1), None, (0,0), 0 , 1, '2. learner_2 (2)' , '.', 1.),
Points(() , () , None, None , 1 , 1, 'Tie (0)', '.', 1.),
Points(('0','3'), ( 1, 2), None, (0,0), 2 , 1, 'l2 (2)' , '.', 1.),
Points(('0','3'), ( 1, 2), None, (0,0), 2 , 1, '1. learner_1 (2)' , '.', 1.),
Points(('2','3'), ( 0, 0), None, None, "#888", 1, None , '-', .5),
]

Expand Down Expand Up @@ -2180,9 +2180,9 @@ def test_plot_contrast_title(self):
result.plot_contrast(2,1,title='abc')

expected_lines = [
Points(('2','1'), (-2,-1), None, (0,0), 0 , 1, 'l1 (2)' , '.', 1.),
Points(('2','1'), (-2,-1), None, (0,0), 0 , 1, '2. learner_2 (2)' , '.', 1.),
Points(() , () , None, None , 1 , 1, 'Tie (0)', '.', 1.),
Points(('0','3'), ( 1, 2), None, (0,0), 2 , 1, 'l2 (2)' , '.', 1.),
Points(('0','3'), ( 1, 2), None, (0,0), 2 , 1, '1. learner_1 (2)' , '.', 1.),
Points(('2','3'), ( 0, 0), None, None, "#888", 1, None , '-', .5),
]

Expand Down Expand Up @@ -2210,9 +2210,9 @@ def test_plot_contrast_one_environment_env_not_index(self):
result.plot_contrast(2,1,'a')

expected_lines = [
Points(() , () , None, None , 0 , 1, 'l1 (0)' , '.', 1.),
Points(() , () , None, None , 0 , 1, '2. learner_2 (0)' , '.', 1.),
Points(('2',) , (0, ), None, (0, ), 1 , 1, 'Tie (1)', '.', 1.),
Points(('3','1'), (1,2), None, (0,0), 2 , 1, 'l2 (2)' , '.', 1.),
Points(('3','1'), (1,2), None, (0,0), 2 , 1, '1. learner_1 (2)' , '.', 1.),
Points(('2','1'), (0,0), None, None , "#888", 1, None , '-', .5)
]

Expand All @@ -2239,9 +2239,9 @@ def test_plot_contrast_one_environment_env_not_index_mode_prob_mixed_x(self):
result.plot_contrast(2,1,x='a',mode='prob')

expected_lines = [
Points(('2',) , (0, ), None, (0, ), 0 , 1, 'l1 (1)' , '.', 1.),
Points(('2',) , (0, ), None, (0, ), 0 , 1, '2. learner_2 (1)' , '.', 1.),
Points(() , () , None , None , 1 , 1, 'Tie (0)', '.', 1.),
Points(('1','None'), (1,1 ), None, (0,0), 2 , 1, 'l2 (2)' , '.', 1.),
Points(('1','None'), (1,1 ), None, (0,0), 2 , 1, '1. learner_1 (2)' , '.', 1.),
Points(('2','None'), (.5,.5), None, None , "#888", 1, None , '-', .5)
]

Expand All @@ -2268,9 +2268,9 @@ def test_plot_contrast_one_environment_env_not_index_mode_prob(self):
result.plot_contrast(2,1,x='a',mode='prob')

expected_lines = [
Points(('2',) , (0, ), None, (0, ), 0 , 1, 'l1 (1)' , '.', 1.),
Points(('2',) , (0, ), None, (0, ), 0 , 1, '2. learner_2 (1)' , '.', 1.),
Points(() , () , None, None , 1 , 1, 'Tie (0)', '.', 1.),
Points(('1','3'), (1,1 ), None, (0,0), 2 , 1, 'l2 (2)' , '.', 1.),
Points(('1','3'), (1,1 ), None, (0,0), 2 , 1, '1. learner_1 (2)' , '.', 1.),
Points(('2','3'), (.5,.5), None, None , "#888", 1, None , '-', .5)
]

Expand Down

0 comments on commit 4bec85e

Please sign in to comment.