From 968a481d71790c54b1ca0f477d12d7b03a5a949a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20P=2E=20D=C3=BCrholt?= Date: Mon, 11 Sep 2023 15:59:17 +0200 Subject: [PATCH 1/2] add surrogates.py --- .../surrogates/botorch_surrogates.py | 96 +------------------ bofire/data_models/surrogates/surrogates.py | 95 ++++++++++++++++++ 2 files changed, 99 insertions(+), 92 deletions(-) create mode 100644 bofire/data_models/surrogates/surrogates.py diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index e085f92a4..478615387 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -1,11 +1,5 @@ -import itertools -from typing import List, Union +from typing import List, Literal, Union -from pydantic import validator - -from bofire.data_models.base import BaseModel -from bofire.data_models.domain.api import Inputs, Outputs -from bofire.data_models.features.api import TInputTransformSpecs from bofire.data_models.surrogates.empirical import EmpiricalSurrogate from bofire.data_models.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate from bofire.data_models.surrogates.mixed_single_task_gp import ( @@ -14,6 +8,7 @@ from bofire.data_models.surrogates.mlp import MLPEnsemble from bofire.data_models.surrogates.random_forest import RandomForestSurrogate from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate +from bofire.data_models.surrogates.surrogates import Surrogates from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate AnyBotorchSurrogate = Union[ @@ -27,93 +22,10 @@ ] -class BotorchSurrogates(BaseModel): +class BotorchSurrogates(Surrogates): """ "List of botorch surrogates. Behaves similar to a Surrogate.""" + type: Literal["BotorchSurrogates"] = "BotorchSurrogates" surrogates: List[AnyBotorchSurrogate] - - @property - def input_preprocessing_specs(self) -> TInputTransformSpecs: - return { - key: value - for model in self.surrogates - for key, value in model.input_preprocessing_specs.items() - } - - @property - def outputs(self) -> Outputs: - return Outputs( - features=list( - itertools.chain.from_iterable( - [model.outputs.get() for model in self.surrogates] # type: ignore - ) - ) - ) - - def _check_compability(self, inputs: Inputs, outputs: Outputs): - # TODO: add sync option - used_output_feature_keys = self.outputs.get_keys() - if sorted(used_output_feature_keys) != sorted(outputs.get_keys()): - raise ValueError("Output features do not match.") - used_feature_keys = [] - for i, model in enumerate(self.surrogates): - if len(model.inputs) > len(inputs): - raise ValueError( - f"Model with index {i} has more features than acceptable." - ) - for feat in model.inputs: - try: - other_feat = inputs.get_by_key(feat.key) - except KeyError: - raise ValueError(f"Feature {feat.key} not found.") - # now compare the features - # TODO: make more sohisticated comparisons based on the type - # has to to be implemented in features, for the start - # we go with __eq__ - if feat != other_feat: - raise ValueError(f"Features with key {feat.key} are incompatible.") - if feat.key not in used_feature_keys: - used_feature_keys.append(feat.key) - if len(used_feature_keys) != len(inputs): - raise ValueError("Unused features are present.") - - @validator("surrogates") - def validate_surrogates(cls, v, values): - # validate that all surrogates are single output surrogates - # TODO: this restriction has to be removed at some point - for model in v: - if len(model.outputs) != 1: - raise ValueError("Only single output surrogates allowed.") - # check that the output feature keys are distinctw - used_output_feature_keys = list( - itertools.chain.from_iterable([model.outputs.get_keys() for model in v]) - ) - if len(set(used_output_feature_keys)) != len(used_output_feature_keys): - raise ValueError("Output feature keys are not unique across surrogates.") - # get the feature keys present in all surrogates - used_feature_keys = [] - for model in v: - for key in model.inputs.get_keys(): - if key not in used_feature_keys: - used_feature_keys.append(key) - # check that the features and preprocessing steps are equal trough the surrogates - for key in used_feature_keys: - features = [ - model.inputs.get_by_key(key) - for model in v - if key in model.inputs.get_keys() - ] - preproccessing = [ - model.input_preprocessing_specs[key] - for model in v - if key in model.input_preprocessing_specs - ] - if all(features) is False: - raise ValueError(f"Features with key {key} are incompatible.") - if all(i == preproccessing[0] for i in preproccessing) is False: - raise ValueError( - f"Preprocessing steps for features with {key} are incompatible." - ) - return v diff --git a/bofire/data_models/surrogates/surrogates.py b/bofire/data_models/surrogates/surrogates.py new file mode 100644 index 000000000..d2fe3cfb8 --- /dev/null +++ b/bofire/data_models/surrogates/surrogates.py @@ -0,0 +1,95 @@ +import itertools +from typing import Annotated, List + +from pydantic import Field, validator + +from bofire.data_models.base import BaseModel +from bofire.data_models.domain.api import Inputs, Outputs +from bofire.data_models.features.api import TInputTransformSpecs + + +class Surrogates(BaseModel): + type: str + surrogates: Annotated[List, Field(min_items=1)] + + @property + def input_preprocessing_specs(self) -> TInputTransformSpecs: + return { + key: value + for model in self.surrogates + for key, value in model.input_preprocessing_specs.items() + } + + @property + def outputs(self) -> Outputs: + return Outputs( + features=list( + itertools.chain.from_iterable( + [model.outputs.get() for model in self.surrogates] # type: ignore + ) + ) + ) + + @validator("surrogates") + def validate_surrogates(cls, v, values): + # validate that all surrogates are single output surrogates + for model in v: + if len(model.outputs) != 1: + raise ValueError("Only single output surrogates allowed.") + # check that the output feature keys are distinct + used_output_feature_keys = list( + itertools.chain.from_iterable([model.outputs.get_keys() for model in v]) + ) + if len(set(used_output_feature_keys)) != len(used_output_feature_keys): + raise ValueError("Output feature keys are not unique across surrogates.") + # get the feature keys present in all surrogates + used_feature_keys = [] + for model in v: + for key in model.inputs.get_keys(): + if key not in used_feature_keys: + used_feature_keys.append(key) + # check that the features and preprocessing steps are equal trough the surrogates + for key in used_feature_keys: + features = [ + model.inputs.get_by_key(key) + for model in v + if key in model.inputs.get_keys() + ] + preproccessing = [ + model.input_preprocessing_specs[key] + for model in v + if key in model.input_preprocessing_specs + ] + if all(features) is False: + raise ValueError(f"Features with key {key} are incompatible.") + if all(i == preproccessing[0] for i in preproccessing) is False: + raise ValueError( + f"Preprocessing steps for features with {key} are incompatible." + ) + return v + + def _check_compability(self, inputs: Inputs, outputs: Outputs): + used_output_feature_keys = self.outputs.get_keys() + if sorted(used_output_feature_keys) != sorted(outputs.get_keys()): + raise ValueError("Output features do not match.") + used_feature_keys = [] + for i, model in enumerate(self.surrogates): + if len(model.inputs) > len(inputs): + raise ValueError( + f"Model with index {i} has more features than acceptable." + ) + for feat in model.inputs: + try: + other_feat = inputs.get_by_key(feat.key) + except KeyError: + raise ValueError(f"Feature {feat.key} not found.") + # now compare the features + # TODO: make more sohisticated comparisons based on the type + # has to to be implemented in features, for the start + # we go with __eq__ + if feat != other_feat: + raise ValueError(f"Features with key {feat.key} are incompatible.") + if feat.key not in used_feature_keys: + used_feature_keys.append(feat.key) + if len(used_feature_keys) != len(inputs): + raise ValueError("Unused features are present.") From 9b62237e8416875d73c1862f81a34186d2373f0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20P=2E=20D=C3=BCrholt?= Date: Mon, 11 Sep 2023 16:07:31 +0200 Subject: [PATCH 2/2] remove type attribute --- bofire/data_models/surrogates/botorch_surrogates.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index 478615387..de750bf60 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Union +from typing import List, Union from bofire.data_models.surrogates.empirical import EmpiricalSurrogate from bofire.data_models.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate @@ -27,5 +27,4 @@ class BotorchSurrogates(Surrogates): Behaves similar to a Surrogate.""" - type: Literal["BotorchSurrogates"] = "BotorchSurrogates" surrogates: List[AnyBotorchSurrogate]