from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep
from ax.modelbridge.registry import Models, ModelRegistryBase
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.modelbridge_utils import get_pending_observation_features
from ax.utils.testing.core_stubs import get_branin_search_space, get_branin_experiment
GenerationStrategy (API reference) is a key abstraction in Ax:
Scheduler etc. (tutorials for all those higher-level APIs are here: https://ax.dev/tutorials/).
This tutorial walks through a few examples of generation strategies and discusses their important settings. Before reading it, we recommend familiarizing yourself with how Model and ModelBridge work in Ax: https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack.
Contents:
GenerationStep as a building block of the generation strategy
GenerationStep settings
GenerationStep-s together
max_parallelism enforcement and handling the MaxParallelismReachedException
GenerationStrategy storage
GenerationStrategy.gen produces GeneratorRun-s, not Trial-s
model_kwargs elements that don't have associated serialization logic in Ax
Models registry enum entries over a factory function?
Models?
gs = GenerationStrategy(
    steps=[
        # 1. Initialization step (does not require pre-existing data and is well-suited for
        # initial sampling of the search space)
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
            min_trials_observed=3,  # How many trials need to be completed to move to next model
            max_parallelism=5,  # Max parallelism for this step
            model_kwargs={"seed": 999},  # Any kwargs you want passed into the model
            model_gen_kwargs={},  # Any kwargs you want passed to `modelbridge.gen`
        ),
        # 2. Bayesian optimization step (requires data obtained from previous phase and learns
        # from all data available at the time of each new candidate generation call)
        GenerationStep(
            model=Models.GPEI,
            num_trials=-1,  # No limitation on how many trials should be produced from this step
            max_parallelism=3,  # Parallelism limit for this step, often lower than for Sobol
            # More on parallelism vs. required samples in BayesOpt:
            # https://ax.dev/docs/bayesopt.html#tradeoff-between-parallelism-and-total-number-of-trials
        ),
    ]
)
Ax provides a choose_generation_strategy utility, which can auto-select a suitable generation strategy given a search space and an array of other optional settings. The utility is fairly simple at the moment, but additional development (support for multi-objective optimization, multi-fidelity optimization, Bayesian optimization with categorical kernels etc.) is coming soon.
gs = choose_generation_strategy(
    # Required arguments:
    search_space=get_branin_search_space(),  # Ax `SearchSpace`
    # Some optional arguments (shown with their defaults), see API docs for more settings:
    # https://ax.dev/api/modelbridge.html#module-ax.modelbridge.dispatch_utils
    use_batch_trials=False,  # Whether this GS will be used to generate 1-arm `Trial`-s or `BatchTrials`
    no_bayesian_optimization=False,  # Use quasi-random candidate generation without BayesOpt
    max_parallelism_override=None,  # Integer, to which to set the `max_parallelism` setting of all steps in this GS
)
gs
[INFO 04-26 20:20:22] ax.modelbridge.dispatch_utils: Using Bayesian optimization since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 04-26 20:20:22] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])
While often used through Service or Loop API or other higher-order abstractions like the Ax Scheduler (where the generation strategy is used to fit models and produce candidates from them under-the-hood), it's also possible to use the GS directly, in place of a ModelBridge instance. The interface of GenerationStrategy.gen is the same as ModelBridge.gen.
experiment = get_branin_experiment()
[INFO 04-26 20:20:22] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
Note that it's important to pass pending observations to gen to avoid getting the same points re-suggested. Without the pending_observations argument, Ax models are not aware of points that should be excluded from generation. Points are considered "pending" when they belong to STAGED, RUNNING, or ABANDONED trials (the latter are included so that the model does not re-suggest points that are considered "bad").
If the call to get_pending_observation_features becomes slow in your setup (since it performs data-fetching etc.), you can opt for get_pending_observation_features_based_on_trial_status (also from ax.modelbridge.modelbridge_utils), but note the limitations of that utility (detailed in its docstring).
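As a rough illustration of what "pending" means here, the sketch below filters points by trial status. It is plain Python rather than Ax code: the Status enum and the pending_points helper are hypothetical names introduced for illustration, and the real get_pending_observation_features also looks at fetched data, which this sketch ignores.

```python
from enum import Enum, auto
from typing import Dict, List, Tuple


class Status(Enum):
    """Hypothetical stand-in for the trial statuses mentioned above."""

    CANDIDATE = auto()
    STAGED = auto()
    RUNNING = auto()
    COMPLETED = auto()
    ABANDONED = auto()


# Statuses whose points should not be suggested again, per the text above.
PENDING_STATUSES = {Status.STAGED, Status.RUNNING, Status.ABANDONED}

Point = Dict[str, float]


def pending_points(trials: List[Tuple[Status, Point]]) -> List[Point]:
    """Collect the points of all trials that count as pending."""
    return [point for status, point in trials if status in PENDING_STATUSES]


trials = [
    (Status.RUNNING, {"x1": 0.5, "x2": 1.0}),
    (Status.COMPLETED, {"x1": 2.0, "x2": 3.0}),
    (Status.ABANDONED, {"x1": -1.0, "x2": 4.0}),
]
print(pending_points(trials))  # the RUNNING and ABANDONED points only
```

The completed trial is left out because its result is already available as data, so it does not need to be flagged separately.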
generator_run = gs.gen(
    experiment=experiment,  # Ax `Experiment`, for which to generate new candidates
    data=None,  # Ax `Data` to use for model training, optional.
    n=1,  # Number of candidate arms to produce
    pending_observations=get_pending_observation_features(experiment),  # Points that should not be re-generated
    # Any other kwargs specified will be passed through to `ModelBridge.gen` along with `GenerationStep.model_gen_kwargs`
)
generator_run
GeneratorRun(1 arms, total weight 1.0)
Then we can add the newly produced GeneratorRun to the experiment as a Trial (or BatchTrial if n > 1):
trial = experiment.new_trial(generator_run)
trial
Trial(experiment_name='branin_test_experiment', index=0, status=TrialStatus.CANDIDATE, arm=Arm(name='0_0', parameters={'x1': -1.3248933106660843, 'x2': 3.8894078135490417}))
Important notes on GenerationStrategy.gen:
If the data argument above is not specified, GS will pull experiment data from cache via experiment.lookup_data.
Without pending_observations, the GS (and any model in Ax) could produce the same candidate over and over, since without that argument the model is not 'aware' that the candidate is part of a RUNNING or ABANDONED trial and should not be re-suggested. In cases where get_pending_observation_features is too slow and the experiment consists of 1-arm Trial-s only, it's possible to use get_pending_observation_features_based_on_trial_status instead (found in the same file).
Note that when using the Ax Service API, one of the arguments to AxClient is choose_generation_strategy_kwargs; specifying that argument is a convenient way to influence the choice of generation strategy in AxClient without manually specifying a full GenerationStrategy.
GenerationStep as a building block of generation strategy
GenerationStep
There are two ways of specifying a model for a generation step: via an entry in the Models enum or via a 'factory function', i.e. a callable model constructor (e.g. get_GPEI and other factory functions in the same file). Note that the latter path, a factory function, prevents GenerationStrategy storage and is generally discouraged.
GenerationStep settings
All of the available settings are described in the documentation:
print(GenerationStep.__doc__)
One step in the generation strategy, corresponds to a single model.
Describes the model, how many trials will be generated with this model, what
minimum number of observations is required to proceed to the next model, etc.
NOTE: Model can be specified either from the model registry
(`ax.modelbridge.registry.Models` or using a callable model constructor. Only
models from the registry can be saved, and thus optimization can only be
resumed if interrupted when using models from the registry.
Args:
model: A member of `Models` enum or a callable returning an instance of
`ModelBridge` with an instantiated underlying `Model`. Refer to
`ax/modelbridge/factory.py` for examples of such callables.
num_trials: How many trials to generate with the model from this step.
If set to -1, trials will continue to be generated from this model
as long as `generation_strategy.gen` is called (available only for
the last of the generation steps).
min_trials_observed: How many trials must be completed before the
generation strategy can proceed to the next step. Defaults to 0.
If `num_trials` of a given step have been generated but `min_trials_
observed` have not been completed, a call to `generation_strategy.gen`
will fail with a `DataRequiredError`.
max_parallelism: How many trials generated in the course of this step are
allowed to be run (i.e. have `trial.status` of `RUNNING`) simultaneously.
If `max_parallelism` trials from this step are already running, a call
to `generation_strategy.gen` will fail with a `MaxParallelismReached
Exception`, indicating that more trials need to be completed before
generating and running next trials.
use_update: Whether to use `model_bridge.update` instead or reinstantiating
model + bridge on every call to `gen` within a single generation step.
NOTE: use of `update` on stateful models that do not implement `_get_state`
may result in inability to correctly resume a generation strategy from
a serialized state.
enforce_num_trials: Whether to enforce that only `num_trials` are generated
from the given step. If False and `num_trials` have been generated, but
`min_trials_observed` have not been completed, `generation_strategy.gen`
will continue generating trials from the current step, exceeding `num_
trials` for it. Allows to avoid `DataRequiredError`, but delays
proceeding to next generation step.
model_kwargs: Dictionary of kwargs to pass into the model constructor on
instantiation. E.g. if `model` is `Models.SOBOL`, kwargs will be applied
as `Models.SOBOL(**model_kwargs)`; if `model` is `get_sobol`, `get_sobol(
**model_kwargs)`. NOTE: if generation strategy is interrupted and
resumed from a stored snapshot and its last used model has state saved on
its generator runs, `model_kwargs` is updated with the state dict of the
model, retrieved from the last generator run of this generation strategy.
model_gen_kwargs: Each call to `generation_strategy.gen` performs a call to the
step's model's `gen` under the hood; `model_gen_kwargs` will be passed to
the model's `gen` like so: `model.gen(**model_gen_kwargs)`.
index: Index of this generation step, for use internally in `Generation
Strategy`. Do not assign as it will be reassigned when instantiating
`GenerationStrategy` with a list of its steps.
should_deduplicate: Whether to deduplicate the parameters of proposed arms
against those of previous arms via rejection sampling. If this is True,
the generation strategy will discard generator runs produced from the
generation step that has `should_deduplicate=True` if they contain arms
already present on the experiment and replace them with new generator runs.
If no generator run with entirely unique arms could be produced in 5
attempts, a `GenerationStrategyRepeatedPoints` error will be raised, as we
assume that the optimization converged when the model can no longer suggest
unique arms.
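The should_deduplicate entry in the docstring above describes rejection sampling with a cap of 5 attempts. A minimal sketch of that idea in plain Python follows. It is not Ax's implementation: gen_unique, RepeatedPoints, and the Arm alias are hypothetical names, and the sketch handles a single arm per proposal, whereas the docstring talks about whole generator runs.

```python
from typing import Callable, Set, Tuple

Arm = Tuple[float, float]  # hypothetical stand-in for an arm's parameter values

MAX_ATTEMPTS = 5  # the attempt cap mentioned in the docstring


class RepeatedPoints(Exception):
    """Local stand-in for the `GenerationStrategyRepeatedPoints` error."""


def gen_unique(propose: Callable[[], Arm], existing: Set[Arm]) -> Arm:
    """Ask `propose` for candidates until one is not already in `existing`."""
    for _ in range(MAX_ATTEMPTS):
        candidate = propose()
        if candidate not in existing:
            return candidate
    # Every attempt repeated a known arm; treat this as a sign of convergence.
    raise RepeatedPoints(f"no unique arm found in {MAX_ATTEMPTS} attempts")


# Usage: the first two proposals repeat a known arm and are discarded.
proposals = iter([(0.0, 0.0), (0.0, 0.0), (1.0, 2.0)])
seen = {(0.0, 0.0)}
print(gen_unique(lambda: next(proposals), seen))  # (1.0, 2.0)
```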
GenerationStep-s together
A GenerationStrategy moves from one step to another when:
1) N=num_trials generator runs were produced and attached as trials to the experiment, AND
2) M=min_trials_observed have been completed and have data.
Caveat: enforce_num_trials setting:
If enforce_num_trials=True for a given generation step and 1) is reached but 2) is not yet reached, the generation strategy will raise a DataRequiredError, indicating that more trials need to be completed before the next step.
If enforce_num_trials=False, the GS will continue producing generator runs from the current step until 2) is reached.
max_parallelism enforcement
Generation strategy can restrict the number of trials that can be run simultaneously (to encourage sequential optimization, which benefits Bayesian optimization performance). When the parallelism limit is reached, a call to GenerationStrategy.gen will result in a MaxParallelismReachedException.
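The transition rule, the enforce_num_trials caveat, and the parallelism limit described above can be summarized in one decision function. The sketch below is plain Python written for illustration only: StepConfig, decide, and the two exception classes are hypothetical stand-ins (not Ax classes), and the order in which the checks run is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


class DataRequired(Exception):
    """Local stand-in for Ax's DataRequiredError."""


class MaxParallelismReached(Exception):
    """Local stand-in for Ax's MaxParallelismReachedException."""


@dataclass
class StepConfig:
    num_trials: int  # -1 means "no limit"
    min_trials_observed: int = 0
    max_parallelism: Optional[int] = None  # None disables the limit
    enforce_num_trials: bool = True


def decide(step: StepConfig, generated: int, completed: int, running: int) -> str:
    """Return "advance" or "generate"; raise when neither is allowed."""
    quota_met = step.num_trials != -1 and generated >= step.num_trials
    enough_data = completed >= step.min_trials_observed
    if quota_met and enough_data:
        return "advance"  # conditions 1) and 2) both hold
    if quota_met and step.enforce_num_trials:
        raise DataRequired("complete more trials before the next step")
    if step.max_parallelism is not None and running >= step.max_parallelism:
        raise MaxParallelismReached("wait for running trials to complete")
    return "generate"  # keep producing trials from the current step


sobol = StepConfig(num_trials=5, min_trials_observed=3, max_parallelism=5)
print(decide(sobol, generated=2, completed=0, running=2))  # generate
print(decide(sobol, generated=5, completed=3, running=2))  # advance
```

With enforce_num_trials=False the second branch is skipped, so the function keeps returning "generate" past num_trials until enough trials have completed.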
The correct way to handle this exception:
Make sure that GenerationStep.max_parallelism is configured correctly for all steps in your generation strategy (to disable it completely, configure GenerationStep.max_parallelism=None),
When the exception is raised, wait for more trials to complete before generating new ones, and log each completion via trial.mark_completed.
GenerationStrategy storage
When used through Service API or Scheduler, generation strategy will be automatically stored to SQL or JSON via specifying DBSettings to either AxClient or Scheduler (details in respective tutorials in the "Tutorials" page). Generation strategy can also be stored to SQL or JSON individually, as shown below.
More detail on SQL and JSON storage in Ax generally can be found in the "Building Blocks of Ax" tutorial.
For SQL storage setup in Ax, read through the "Storage" documentation page.
Note that unlike an Ax experiment, a generation strategy does not have a name or another unique identifier. Therefore, a generation strategy is stored in association with experiment and can be retrieved by the associated experiment's name.
from ax.storage.sqa_store.db import create_all_tables, get_engine, init_engine_and_session_factory
from ax.storage.sqa_store.load import load_experiment, load_generation_strategy_by_experiment_name
from ax.storage.sqa_store.save import save_experiment, save_generation_strategy
init_engine_and_session_factory(url='sqlite:///foo2.db')
engine = get_engine()
create_all_tables(engine)
save_experiment(experiment)
save_generation_strategy(gs)
experiment = load_experiment(experiment_name=experiment.name)
gs = load_generation_strategy_by_experiment_name(
    experiment_name=experiment.name,
    experiment=experiment,  # Can optionally specify experiment object to avoid loading it from database twice
)
gs
[INFO 04-26 20:20:22] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
/home/runner/work/Ax/Ax/ax/storage/sqa_store/load.py:231: SAWarning: TypeDecorator JSONEncodedText() will not produce a cache key because the ``cache_ok`` attribute is not set to True. This can have significant performance implications including some performance degradations in comparison to prior SQLAlchemy versions. Set this attribute to True if this type object's state is safe to use in a cache key, or False to disable this warning. (Background on this error at: https://sqlalche.me/e/14/cprf)
.filter_by(name=experiment_name)
[INFO 04-26 20:20:22] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])
from ax.storage.json_store.encoder import object_to_json
from ax.storage.json_store.decoder import object_from_json
gs_json = object_to_json(gs) # Can be written to a file or string via `json.dump` etc.
gs = object_from_json(gs_json) # Decoded back from JSON (can be loaded from file, string via `json.load` etc.)
gs
[INFO 04-26 20:20:22] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])
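The comments in the JSON example above mention writing the encoded object with json.dump and reading it back with json.load. A minimal sketch of that file round trip follows. The snapshot dict is a made-up stand-in for the output of object_to_json (its real structure is not shown in this tutorial), and the file name is arbitrary.

```python
import json
import os
import tempfile

# Made-up stand-in for `object_to_json(gs)`; the real structure will differ.
snapshot = {"name": "Sobol+GPEI", "steps": [{"model": "SOBOL", "num_trials": 5}]}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "generation_strategy.json")
    # Write the JSON-compatible dict to disk.
    with open(path, "w") as f:
        json.dump(snapshot, f)
    # Read it back, as one would when resuming an optimization.
    with open(path) as f:
        restored = json.load(f)

print(restored == snapshot)
```

With Ax, the dict read back from the file would then be passed to object_from_json, as in the example above.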
Below is a list of important "gotchas" of using generation strategy (especially outside of the higher-level APIs like the Service API or the Scheduler):
GenerationStrategy.gen produces GeneratorRun-s, not trials
Since GenerationStrategy.gen mimics ModelBridge.gen and allows for human-in-the-loop usage mode, a call to gen produces a GeneratorRun, which can then be added (or altered before addition or not added at all) to a Trial or BatchTrial on a given experiment. So it's important to add the generator run to a trial, since otherwise it will not be attached to the experiment on its own.
generator_run = gs.gen(
experiment=experiment, n=1, pending_observations=get_pending_observation_features(experiment)
)
experiment.new_trial(generator_run)
Trial(experiment_name='branin_test_experiment', index=1, status=TrialStatus.CANDIDATE, arm=Arm(name='1_0', parameters={'x1': -2.3552652867510915, 'x2': 1.2547599943354726}))
model_kwargs elements that do not define serialization logic in Ax
Note that passing objects that are not yet serializable in Ax (e.g. a BoTorch Prior object) as part of GenerationStep.model_kwargs or GenerationStep.model_gen_kwargs will prevent correct generation strategy storage. If this becomes a problem, feel free to open an issue on our GitHub: https://github.com/facebook/Ax/issues to get help with adding storage support for a given object.
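One way to spot such entries early is to try encoding each kwarg before building the generation strategy. The sketch below uses the standard-library json module as a stand-in for Ax's own encoder, which is an assumption made for illustration: Ax can encode many types that plain json cannot, so this check is stricter than what Ax actually requires. The helper non_serializable_keys and the FakePrior class are hypothetical names.

```python
import json
from typing import Any, Dict, List


def non_serializable_keys(kwargs: Dict[str, Any]) -> List[str]:
    """Return the keys whose values plain `json` cannot encode."""
    bad = []
    for key, value in kwargs.items():
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            bad.append(key)
    return bad


class FakePrior:
    """Hypothetical object with no serialization logic."""


model_kwargs = {"seed": 999, "prior": FakePrior()}
print(non_serializable_keys(model_kwargs))  # ['prior']
```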
Why prefer Models enum entries over a factory function?
Models.GPEI captures all arguments to the model and model bridge and stores them on the generator runs subsequently produced by the model. Since the capturing logic is part of the Models.__call__ function, it is not present in a factory function. Furthermore, there is no safe and flexible way to serialize callables in Python.
While a factory function is more flexible (it accepts arbitrary inputs and produces a ModelBridge with an underlying Model instance based on them), it is not standard in terms of its inputs. Models introduces a standardized interface, making it easy to adapt any example to one's specific case.
How to request that a modeling setup be added to Models and natively supported in Ax?
Please open a GitHub issue to request a new modeling setup in Ax (or for any other questions or requests).