Source code for ax.benchmark.problems.synthetic.bandit

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from warnings import warn

import numpy as np

from ax.benchmark.benchmark_problem import BenchmarkProblem, get_soo_opt_config
from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction
from ax.core.parameter import ChoiceParameter, ParameterType
from ax.core.search_space import SearchSpace


def get_baseline(num_choices: int, n_sims: int = 100000000) -> float:
    """
    Compute the baseline value.

    The baseline for this problem takes noise into account, because it uses
    the inference trace, as well as the bandit structure, which allows all
    arms to be run in one noisy batch: run a BatchTrial with every arm at
    equal allocation, choose the arm with the best observed value and take
    its true value, then take the expectation of the outcome of this process.
    """
    # With the unit noise budget split evenly across all arms in one batch,
    # each arm's effective observation-noise SD is sqrt(num_choices).
    noise_per_arm = num_choices**0.5
    sim_observed_effects = (
        np.random.normal(0, noise_per_arm, (n_sims, num_choices))
        + np.arange(num_choices)[None, :]
    )
    identified_best_arm = sim_observed_effects.argmin(axis=1)
    # Because of the use of IdentityTestFunction, arm i's true value is i, so
    # the index of the identified best arm is also its true value.
    baseline = identified_best_arm.mean()
    return baseline
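
A minimal usage sketch (not part of the module source): get_baseline can
reproduce or extend the tabulated values consumed by get_bandit_problem
below. The estimate is seed-dependent, so only approximate agreement with
the tabulated value is expected.

    est = get_baseline(num_choices=10, n_sims=1_000_000)
    print(f"{est:.3f}")  # close to the tabulated 1.40736478 for 10 choices
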
def get_bandit_problem(num_choices: int = 30, num_trials: int = 3) -> BenchmarkProblem:
    parameter = ChoiceParameter(
        name="x0",
        parameter_type=ParameterType.INT,
        values=list(range(num_choices)),
        is_ordered=False,
        sort_values=False,
    )
    search_space = SearchSpace(parameters=[parameter])
    test_function = IdentityTestFunction()
    optimization_config = get_soo_opt_config(
        outcome_names=test_function.outcome_names, observe_noise_sd=True
    )
    # Precomputed with `get_baseline` for common values of `num_choices`.
    baselines = {
        10: 1.40736478,
        30: 2.4716703,
        100: 4.403284,
    }
    if num_choices not in baselines:
        warn(
            f"Baseline value is not available for num_choices={num_choices}. Use "
            "`get_baseline` to compute the baseline and add it to `baselines`."
        )
        # Fall back to the baseline for the default of 30 choices.
        baseline_value = baselines[30]
    else:
        baseline_value = baselines[num_choices]
    return BenchmarkProblem(
        name="Bandit",
        num_trials=num_trials,
        search_space=search_space,
        optimization_config=optimization_config,
        # Arm x0=0 has the lowest true value under the identity test function.
        optimal_value=0,
        baseline_value=baseline_value,
        test_function=test_function,
        report_inference_value_as_trace=True,
        noise_std=1.0,
        status_quo_params={"x0": num_choices // 2},
    )
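
A minimal usage sketch (not part of the module source): constructing the
default 30-arm problem and inspecting a few of its fields, which mirror the
BenchmarkProblem constructor arguments above.

    problem = get_bandit_problem()
    print(problem.optimal_value)      # 0: arm x0=0 has the lowest true value
    print(problem.baseline_value)     # 2.4716703, the tabulated 30-choice value
    print(problem.status_quo_params)  # {'x0': 15}, the middle arm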