This tutorial illustrates the core visualization utilities available in Ax.
import numpy as np
from ax import (
Arm,
ComparisonOp,
RangeParameter,
ParameterType,
SearchSpace,
SimpleExperiment,
OutcomeConstraint,
)
from ax.metrics.l2norm import L2NormMetric
from ax.modelbridge.cross_validation import cross_validate
from ax.modelbridge.registry import Models
from ax.plot.contour import interact_contour, plot_contour
from ax.plot.diagnostic import interact_cross_validation
from ax.plot.scatter import(
interact_fitted,
plot_objective_vs_constraints,
tile_fitted,
)
from ax.plot.slice import plot_slice
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import render, init_notebook_plotting
# Inject the Plotly library into the notebook (as the INFO log it emits reports)
# so the figures passed to render() below can be displayed.
init_notebook_plotting()
[INFO 08-29 13:34:46] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
The visualizations require an experiment object and a model fit on the evaluated data. The routine below is a copy of the Developer API tutorial, so the explanation is omitted here. How to retrieve the experiment and model objects for each API paradigm is shown in the respective tutorials.
# Standard deviation of the Gaussian noise that the evaluation function adds
# to, and reports for, each metric.
noise_sd = 0.1
# The six Hartmann6 input dimensions, numbered from 1: x1, x2, ..., x6.
param_names = [f"x{idx}" for idx in range(1, 7)]
def noisy_hartmann_evaluation_function(parameterization):
    """Evaluate a noisy Hartmann6 objective and a noisy L2-norm for one arm.

    Args:
        parameterization: mapping from parameter name to value. A value is
            looked up (via ``.get``) for every name in ``param_names``, so a
            missing name yields ``None`` in the input vector.

    Returns:
        dict with keys ``"hartmann6"`` and ``"l2norm"``. Each value is a
        ``(noisy value, noise_sd)`` tuple, presumably interpreted by Ax as
        ``(mean, SEM)``.
    """
    x = np.array(list(map(parameterization.get, param_names)))
    # One draw per metric, both with the same standard deviation.
    hartmann_noise, l2_noise = np.random.normal(0, noise_sd, size=2)
    l2 = np.sqrt((x ** 2).sum())
    return {
        "hartmann6": (hartmann6(x) + hartmann_noise, noise_sd),
        "l2norm": (l2 + l2_noise, noise_sd),
    }
# Search space: one continuous parameter on [0.0, 1.0] for each name in
# param_names.
hartmann_search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name=dim_name,
            parameter_type=ParameterType.FLOAT,
            lower=0.0,
            upper=1.0,
        )
        for dim_name in param_names
    ],
)
# Experiment that minimizes the noisy Hartmann6 objective. Its outcome
# constraint is l2norm <= 1.25 (absolute, not relative to a status quo).
# Metric values come from noisy_hartmann_evaluation_function.
exp = SimpleExperiment(
    # Named after the function actually optimized (was the copy-paste
    # leftover "test_branin").
    name="test_hartmann6",
    search_space=hartmann_search_space,
    evaluation_function=noisy_hartmann_evaluation_function,
    objective_name="hartmann6",
    minimize=True,
    outcome_constraints=[
        OutcomeConstraint(
            # NOTE(review): this metric's noise_sd (0.2) differs from the
            # noise_sd (0.1) the evaluation function reports. Presumably
            # SimpleExperiment takes values from the evaluation function, so
            # this is unused; confirm before relying on it.
            metric=L2NormMetric(
                name="l2norm", param_names=param_names, noise_sd=0.2
            ),
            op=ComparisonOp.LEQ,
            bound=1.25,
            relative=False,
        )
    ],
)
After running `N_BATCHES=15` rounds of optimization, fit a final GP using all the data, to feed into the plots.
# Optimization budget. N_RANDOM Sobol arms seed the experiment, then
# N_BATCHES GP-EI rounds each generate BATCH_SIZE arm(s).
N_RANDOM = 5
BATCH_SIZE = 1
N_BATCHES = 15

# Initial design: a single batch trial holding N_RANDOM Sobol arms.
sobol = Models.SOBOL(exp.search_space)
exp.new_batch_trial(generator_run=sobol.gen(N_RANDOM))

# Each round refits the GP on the data from exp.eval() (presumably all trials
# evaluated so far), then adds a trial with the newly generated arm(s).
# NOTE(review): new_trial creates a single-arm Trial, so this assumes
# BATCH_SIZE == 1. Larger batches presumably need new_batch_trial; confirm.
for _ in range(N_BATCHES):
    intermediate_gp = Models.GPEI(experiment=exp, data=exp.eval())
    exp.new_trial(generator_run=intermediate_gp.gen(BATCH_SIZE))

# Final GP fit on all evaluated data; this is the model passed to the plots.
model = Models.GPEI(experiment=exp, data=exp.eval())
The plot below shows the response surface for the `hartmann6` metric as a function of the `x1` and `x2` parameters.
The other parameters are fixed in the middle of their respective ranges, which in this example is 0.5 for all of them.
The plot below allows toggling between different pairs of parameters to view the contours.
This plot illustrates the tradeoffs achievable for 2 different metrics. The plot takes the x-axis metric as input (usually the objective) and allows toggling among all other metrics for the y-axis.
This is useful for getting a sense of the Pareto frontier (i.e., the best objective value achievable for different bounds on the constraint).
Cross-validation (CV) plots are useful for checking how well the model's predictions are calibrated against the actual measurements. If all points are close to the dashed line, the model is a good predictor of the real data.
Slice plots show the metric outcome as a function of one parameter while the others are held fixed. They serve a similar function to contour plots.
# Slice plot of the model's hartmann6 prediction as a function of x2, with the
# other parameters held fixed.
# NOTE(review): the recorded run of this cell raised a plotly ValueError
# rejecting 'transparent' as a scatter.line color. This looks like a version
# mismatch between Ax and the installed plotly, whose color validator does not
# list 'transparent'. Confirm, and upgrade plotly if that is the cause.
render(plot_slice(model, "x2", "hartmann6"))
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-9-77cae64bb37d> in <module> ----> 1 render(plot_slice(model, "x2", "hartmann6")) ~/anaconda3/lib/python3.7/site-packages/ax/plot/slice.py in plot_slice(model, param_name, metric_name, generator_runs_dict, relative, density, slice_values, fixed_features) 215 } 216 --> 217 fig = go.Figure(data=traces, layout=layout) # pyre-ignore[16] 218 return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC) 219 ~/anaconda3/lib/python3.7/site-packages/plotly/graph_objs/_figure.py in __init__(self, data, layout, frames, skip_invalid, **kwargs) 550 """ 551 super(Figure, --> 552 self).__init__(data, layout, frames, skip_invalid, **kwargs) 553 554 def add_area( ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in __init__(self, data, layout_plotly, frames, skip_invalid, **kwargs) 154 # ### Import traces ### 155 data = self._data_validator.validate_coerce(data, --> 156 skip_invalid=skip_invalid) 157 158 # ### Save tuple of trace objects ### ~/anaconda3/lib/python3.7/site-packages/_plotly_utils/basevalidators.py in validate_coerce(self, v, skip_invalid) 2333 else: 2334 trace = self.class_map[trace_type]( -> 2335 skip_invalid=skip_invalid, **v_copy) 2336 res.append(trace) 2337 else: ~/anaconda3/lib/python3.7/site-packages/plotly/graph_objs/__init__.py in __init__(self, arg, cliponaxis, connectgaps, customdata, customdatasrc, dx, dy, error_x, error_y, fill, fillcolor, groupnorm, hoverinfo, hoverinfosrc, hoverlabel, hoveron, hovertemplate, hovertemplatesrc, hovertext, hovertextsrc, ids, idssrc, legendgroup, line, marker, meta, metasrc, mode, name, opacity, orientation, r, rsrc, selected, selectedpoints, showlegend, stackgaps, stackgroup, stream, t, text, textfont, textposition, textpositionsrc, textsrc, tsrc, uid, uirevision, unselected, visible, x, x0, xaxis, xcalendar, xsrc, y, y0, yaxis, ycalendar, ysrc, **kwargs) 39591 
self['legendgroup'] = legendgroup if legendgroup is not None else _v 39592 _v = arg.pop('line', None) > 39593 self['line'] = line if line is not None else _v 39594 _v = arg.pop('marker', None) 39595 self['marker'] = marker if marker is not None else _v ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in __setitem__(self, prop, value) 3306 # ### Handle compound property ### 3307 if isinstance(validator, CompoundValidator): -> 3308 self._set_compound_prop(prop, value) 3309 3310 # ### Handle compound array property ### ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in _set_compound_prop(self, prop, val) 3619 validator = self._validators.get(prop) 3620 # type: BasePlotlyType -> 3621 val = validator.validate_coerce(val, skip_invalid=self._skip_invalid) 3622 3623 # Save deep copies of current and new states ~/anaconda3/lib/python3.7/site-packages/_plotly_utils/basevalidators.py in validate_coerce(self, v, skip_invalid) 2129 2130 elif isinstance(v, dict): -> 2131 v = self.data_class(v, skip_invalid=skip_invalid) 2132 2133 elif isinstance(v, self.data_class): ~/anaconda3/lib/python3.7/site-packages/plotly/graph_objs/scatter/__init__.py in __init__(self, arg, color, dash, shape, simplify, smoothing, width, **kwargs) 2436 # ---------------------------------- 2437 _v = arg.pop('color', None) -> 2438 self['color'] = color if color is not None else _v 2439 _v = arg.pop('dash', None) 2440 self['dash'] = dash if dash is not None else _v ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in __setitem__(self, prop, value) 3315 # ### Handle simple property ### 3316 else: -> 3317 self._set_prop(prop, value) 3318 3319 # Handle non-scalar case ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in _set_prop(self, prop, val) 3560 return 3561 else: -> 3562 raise err 3563 3564 # val is None ~/anaconda3/lib/python3.7/site-packages/plotly/basedatatypes.py in _set_prop(self, prop, val) 3555 validator = self._validators.get(prop) 3556 
try: -> 3557 val = validator.validate_coerce(val) 3558 except ValueError as err: 3559 if self._skip_invalid: ~/anaconda3/lib/python3.7/site-packages/_plotly_utils/basevalidators.py in validate_coerce(self, v, should_raise) 1162 validated_v = self.vc_scalar(v) 1163 if validated_v is None and should_raise: -> 1164 self.raise_invalid_val(v) 1165 1166 v = validated_v ~/anaconda3/lib/python3.7/site-packages/_plotly_utils/basevalidators.py in raise_invalid_val(self, v, inds) 275 typ=type_str(v), 276 v=repr(v), --> 277 valid_clr_desc=self.description())) 278 279 def raise_invalid_elements(self, invalid_els): ValueError: Invalid value of type 'builtins.str' received for the 'color' property of scatter.line Received value: 'transparent' The 'color' property is a color and may be specified as: - A hex string (e.g. '#ff0000') - An rgb/rgba string (e.g. 'rgb(255,0,0)') - An hsl/hsla string (e.g. 'hsl(0,100%,50%)') - An hsv/hsva string (e.g. 'hsv(0,100%,100%)') - A named CSS color: aliceblue, antiquewhite, aqua, aquamarine, azure, beige, bisque, black, blanchedalmond, blue, blueviolet, brown, burlywood, cadetblue, chartreuse, chocolate, coral, cornflowerblue, cornsilk, crimson, cyan, darkblue, darkcyan, darkgoldenrod, darkgray, darkgrey, darkgreen, darkkhaki, darkmagenta, darkolivegreen, darkorange, darkorchid, darkred, darksalmon, darkseagreen, darkslateblue, darkslategray, darkslategrey, darkturquoise, darkviolet, deeppink, deepskyblue, dimgray, dimgrey, dodgerblue, firebrick, floralwhite, forestgreen, fuchsia, gainsboro, ghostwhite, gold, goldenrod, gray, grey, green, greenyellow, honeydew, hotpink, indianred, indigo, ivory, khaki, lavender, lavenderblush, lawngreen, lemonchiffon, lightblue, lightcoral, lightcyan, lightgoldenrodyellow, lightgray, lightgrey, lightgreen, lightpink, lightsalmon, lightseagreen, lightskyblue, lightslategray, lightslategrey, lightsteelblue, lightyellow, lime, limegreen, linen, magenta, maroon, mediumaquamarine, mediumblue, mediumorchid, 
mediumpurple, mediumseagreen, mediumslateblue, mediumspringgreen, mediumturquoise, mediumvioletred, midnightblue, mintcream, mistyrose, moccasin, navajowhite, navy, oldlace, olive, olivedrab, orange, orangered, orchid, palegoldenrod, palegreen, paleturquoise, palevioletred, papayawhip, peachpuff, peru, pink, plum, powderblue, purple, red, rosybrown, royalblue, saddlebrown, salmon, sandybrown, seagreen, seashell, sienna, silver, skyblue, slateblue, slategray, slategrey, snow, springgreen, steelblue, tan, teal, thistle, tomato, turquoise, violet, wheat, white, whitesmoke, yellow, yellowgreen
Tile plots are useful for viewing the effect of each arm.