Skip to content

Commit

Permalink
Added a title argument to the two result plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
mrucker committed Aug 2, 2023
1 parent 3ed2a1b commit 372dacd
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
10 changes: 7 additions & 3 deletions coba/experiments/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def plot(self,

ax.autoscale(axis='both')

ax.set_title(title, loc='left', pad=15)
ax.set_title(title, loc='left')
ax.set_ylabel(ylabel)
ax.set_xlabel(xlabel)

Expand Down Expand Up @@ -1002,6 +1002,7 @@ def plot_learners(self,
errevery: int = None,
labels : Sequence[str] = None,
colors : Union[int,Sequence[Union[str,int]]] = None,
title : str = None,
xlabel : str = None,
ylabel : str = None,
xlim : Tuple[Optional[Number],Optional[Number]] = None,
Expand All @@ -1028,6 +1029,7 @@ def plot_learners(self,
errevery: This determines the frequency of errorbars. If `None` they appear 5% of the time.
labels: The legend labels to use in the plot. These should be in order of the actual legend labels.
colors: The colors used to plot the learners plot.
title : The title give the plot.
xlabel: The label on the x-axis.
ylabel: The label on the y-axis.
xlim: Define the x-axis limits to plot. If `None` the x-axis limits will be inferred.
Expand Down Expand Up @@ -1073,7 +1075,7 @@ def plot_learners(self,
y_location = "Total" if x != 'index' else ""
y_avg_type = ("Instant" if span == 1 else f"Span {span}" if span else "Progressive")
y_samples = f"({len(Y)} Environments)"
title = ' '.join(filter(None,[y_location, y_avg_type, ylabel, y_samples]))
title = title if title is not None else (' '.join(filter(None,[y_location, y_avg_type, ylabel, y_samples])))

xrotation = 90 if x != 'index' and len(lines[0].X)>5 else 0
yrotation = 0
Expand Down Expand Up @@ -1101,6 +1103,7 @@ def plot_contrast(self,
errevery: int = None,
labels : Sequence[str] = None,
colors : Sequence[str] = None,
title : str = None,
xlabel : str = None,
ylabel : str = None,
xlim : Tuple[Optional[Number],Optional[Number]] = None,
Expand Down Expand Up @@ -1129,6 +1132,7 @@ def plot_contrast(self,
the standard error is shown, and if 'sd' the standard deviation is shown.
errevery: This determines the frequency of errorbars. If `None` they appear 5% of the time.
labels: The legend labels to use in the plot. These should be in order of the actual legend labels.
title : The title give the plot.
colors: The colors used to plot the learners plot.
xlabel: The label on the x-axis.
ylabel: The label on the y-axis.
Expand Down Expand Up @@ -1274,7 +1278,7 @@ def plot_contrast(self,

xlabel = xlabel or ("Interaction" if x=='index' else x[0] if len(x) == 1 else x)
ylabel = ylabel or (f"$\Delta$ {y}" if mode=="diff" else f"P($\Delta$ {y} > 0)")
title = f"{ylabel} ({len(_Y)} Environments)"
title = title if title is not None else (f"{ylabel} ({len(_Y)} Environments)")

self._plotter.plot(ax, lines, title, xlabel, ylabel, xlim, ylim, xticks, yticks, xrotation, yrotation, out)

Expand Down
52 changes: 52 additions & 0 deletions coba/tests/test_experiments_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,27 @@ def test_plot_learners_xlabel_ylabel(self):
self.assertEqual(1, len(plotter.plot_calls))
self.assertEqual(expected_lines, plotter.plot_calls[0][1])

def test_plot_learners_title(self):
envs = [['environment_id'],[0]]
lrns = [['learner_id', 'family'],[1,'learner_1'],[2,'learner_2']]
vals = [['evaluator_id'],[0]]
ints = [['environment_id','learner_id','evaluator_id','index','reward'],[0,1,0,1,1],[0,1,0,2,2],[0,2,0,1,1],[0,2,0,2,2]]

plotter = TestPlotter()
result = Result(envs, lrns, vals, ints)

result.set_plotter(plotter)
result.plot_learners(title='abc')

expected_lines = [
Points([1,2],[1,1.5],[],[0,0],0,1,'1. learner_1','-', 1),
Points([1,2],[1,1.5],[],[0,0],1,1,'2. learner_2','-', 1)
]

self.assertEqual("abc", plotter.plot_calls[0][2])
self.assertEqual(1, len(plotter.plot_calls))
self.assertEqual(expected_lines, plotter.plot_calls[0][1])

def test_plot_learners_all_str_to_none(self):
envs = [['environment_id'],[0]]
lrns = [['learner_id', 'family'],[1,'learner_1'],[2,None]]
Expand Down Expand Up @@ -2043,6 +2064,37 @@ def test_plot_contrast_xlabel_ylabel(self):
self.assertEqual('x', plotter.plot_calls[0][3])
self.assertEqual('y', plotter.plot_calls[0][4])

def test_plot_contrast_title(self):
envs = [['environment_id'],[0],[1],[2],[3]]
lrns = [['learner_id', 'family'],[1,'learner_1'],[2,'learner_2']]
vals = [['evaluator_id'],[0]]
ints = [['environment_id','learner_id','evaluator_id','index','reward'],
[0,1,0,1,0],[0,1,0,2,3],[0,1,0,3,9],
[0,2,0,1,1],[0,2,0,2,2],[0,2,0,3,6],
[1,1,0,1,1],[1,1,0,2,2],[1,1,0,3,6],
[1,2,0,1,0],[1,2,0,2,3],[1,2,0,3,9],
[2,1,0,1,0],[2,1,0,2,0],[2,1,0,3,6],
[2,2,0,1,0],[2,2,0,2,3],[2,2,0,3,9],
[3,1,0,1,0],[3,1,0,2,3],[3,1,0,3,9],
[3,2,0,1,0],[3,2,0,2,0],[3,2,0,3,6],
]

plotter = TestPlotter()
result = Result(envs, lrns, vals, ints)

result.set_plotter(plotter)
result.plot_contrast(2,1,title='abc')

expected_lines = [
Points(('2','1'), (-2,-1), None, (0,0), 0 , 1, 'l1 (2)', '.', 1.),
Points(('0','3'), ( 1, 2), None, (0,0), 2 , 1, 'l2 (2)', '.', 1.),
Points(('2','3'), ( 0, 0), None, None, "#888", 1, None , '-', .5),
]

self.assertEqual(1, len(plotter.plot_calls))
self.assertEqual(expected_lines, plotter.plot_calls[0][1])
self.assertEqual('abc', plotter.plot_calls[0][2])

def test_plot_contrast_one_environment_env_not_index(self):
envs = [['environment_id','a'],[0,1],[1,2],[2,3]]
lrns = [['learner_id', 'family'],[1,'learner_1'],[2,'learner_2']]
Expand Down

0 comments on commit 372dacd

Please sign in to comment.