Skip to content

Commit 38915cc

Browse files
committed
add initial good/bad shading support to plots
1 parent 91d5b67 commit 38915cc

File tree

1 file changed

+51
-18
lines changed

1 file changed

+51
-18
lines changed

pypop/notebook_interface/plotting.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
# SPDX-License-Identifier: BSD-3-Clause-Clear
33
# Copyright (c) 2019, The Numerical Algorithms Group, Ltd. All rights reserved.
44

5+
from os import environ
56
from io import BytesIO
7+
from sys import float_info
68
import numpy
79
import pandas
810
import warnings
@@ -94,37 +96,40 @@ def figure(self):
9496
return self._figure
9597

9698
def _repr_html_(self):
97-
if self._figure is None:
98-
self._build_plot()
99-
100-
return file_html(self._figure, INLINE, "")
99+
return file_html(self.figure, INLINE, "")
101100

102101
def _repr_png_(self):
103-
if self._figure is None:
104-
self._build_plot()
102+
if not environ.get("PYPOP_HEADLESS"):
103+
return None
105104

106105
try:
107106
window_size = [int(1.1 * x) for x in self._plot_dims]
108107
except AttributeError:
109108
window_size = (900, 600)
110109

111-
self._figure.toolbar_location = None
110+
self.figure.toolbar_location = None
112111

113112
driver = get_any_webdriver()
114113
driver.set_window_size(*window_size)
115114

116-
img = get_screenshot_as_png(self._figure, driver=driver,)
115+
img = get_screenshot_as_png(self.figure, driver=driver,)
117116

118117
driver.quit()
119118

120119
imgbuffer = BytesIO()
121120
img.save(imgbuffer, format="png")
122121
return imgbuffer.getvalue()
123122

123+
def save_html(self, path):
124+
imgcode = self._repr_html_()
125+
126+
with open(path, "wt") as fh:
127+
fh.write(imgcode)
128+
124129
def save_png(self, path):
125130
imgdata = self._repr_png_()
126131

127-
with open(path, 'wb') as fh:
132+
with open(path, "wb") as fh:
128133
fh.write(imgdata)
129134

130135

@@ -477,7 +482,6 @@ def _build_plot(self):
477482
plot_data["Plotgroups"] = plot_data[self._group_key].apply(
478483
lambda x: "{} {}".format(x, self._group_label)
479484
)
480-
481485
else:
482486
plot_data = self._metrics.metric_data[
483487
[self._xaxis_key, self._yaxis_key]
@@ -486,20 +490,47 @@ def _build_plot(self):
486490

487491
color_map = build_discrete_cmap(plot_data["Plotgroups"].unique())
488492

489-
x_lims = plot_data[self._xaxis_key].min(), plot_data[self._xaxis_key].max()
490-
x_range = x_lims[1] - x_lims[0]
491-
x_range = x_lims[0] - 0.1 * x_range, x_lims[1] + 0.1 * x_range
492-
y_lims = plot_data[self._yaxis_key].min(), plot_data[self._yaxis_key].max()
493-
y_range = y_lims[1] - y_lims[0]
494-
y_range = y_lims[0] - 0.1 * y_range, y_lims[1] + 0.1 * y_range
493+
x_lims = numpy.asarray(
494+
[plot_data[self._xaxis_key].min(), plot_data[self._xaxis_key].max()]
495+
)
496+
y_lims = numpy.asarray(
497+
[plot_data[self._yaxis_key].min(), plot_data[self._yaxis_key].max()]
498+
)
499+
500+
xrange_ideal = x_lims
501+
yrange_ideal = xrange_ideal / x_lims[0]
502+
yrange_80pc = 0.8 * yrange_ideal + 0.2
503+
504+
x_axis_range = x_lims
505+
y_axis_range = min(y_lims[0], yrange_ideal[0]), max(y_lims[1], yrange_ideal[1])
506+
507+
x_expand = numpy.asarray([-0.1, 0.1]) * numpy.abs(numpy.diff(x_lims))
508+
y_expand = numpy.asarray([-0.1, 0.1]) * numpy.abs(numpy.diff(y_lims))
509+
x_axis_range = x_axis_range + x_expand
510+
y_axis_range = y_axis_range + y_expand
495511

496512
self._figure = figure(
497513
tools=["save"],
498514
min_border=0,
499515
aspect_ratio=1.5,
500516
sizing_mode="scale_width",
501-
x_range=x_range,
502-
y_range=y_range,
517+
x_range=x_axis_range,
518+
y_range=y_axis_range,
519+
)
520+
521+
self._figure.varea(
522+
xrange_ideal,
523+
y1=yrange_ideal,
524+
y2=yrange_80pc,
525+
fill_color="green",
526+
fill_alpha=0.4,
527+
)
528+
self._figure.varea(
529+
xrange_ideal,
530+
y1=yrange_80pc,
531+
y2=numpy.zeros_like(yrange_80pc),
532+
fill_color="red",
533+
fill_alpha=0.4,
503534
)
504535

505536
self._figure.xaxis.axis_label_text_font_size = "{}pt".format(self._fontsize)
@@ -510,6 +541,8 @@ def _build_plot(self):
510541
self._figure.xaxis.axis_label = self._xaxis_key
511542
self._figure.yaxis.axis_label = self._yaxis_key
512543

544+
xmin = float_info.max
545+
xmax = 0
513546
for group, groupdata in plot_data.groupby("Plotgroups", sort=False):
514547
groupdata = groupdata.sort_values(self._xaxis_key)
515548
self._figure.square(

0 commit comments

Comments
 (0)