Source code for ax.benchmark.benchmark_test_functions.synthetic
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
import torch
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
[docs]
@dataclass(kw_only=True)
class IdentityTestFunction(BenchmarkTestFunction):
"""
Test function that returns the value of parameter "x0", ignoring any others.
"""
outcome_names: Sequence[str] = field(default_factory=lambda: ["objective"])
n_steps: int = 1
# pyre-fixme[14]: Inconsistent override
[docs]
def evaluate_true(self, params: Mapping[str, float]) -> torch.Tensor:
"""
Return params["x0"] for each outcome for each time step.
Args:
params: A dictionary with key "x0".
"""
value = params["x0"]
return torch.full(
(len(self.outcome_names), self.n_steps), value, dtype=torch.float64
)