Remove hard-coded dtype from best_f buffers (#2725)
Summary:
Pull Request resolved: #2725

Specifying `dtype=float` (which is equivalent to `torch.float64`) causes issues if the user wants to use single precision. See #2724

Removing the `dtype` argument from `torch.as_tensor` means the existing dtype is used if the input is a tensor, and `torch.get_default_dtype()` is used if the input is a Python float or a list of floats. If the input is a NumPy array, the corresponding torch dtype is used (e.g. `torch.float64` for `np.float64`).
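For illustration, a small standalone sketch of this inference behavior (the printed dtypes assume the default dtype is left at `torch.float32`):

```python
import numpy as np
import torch

t32 = torch.tensor([0.5], dtype=torch.float32)

# Tensor input: the existing dtype is kept.
print(torch.as_tensor(t32).dtype)               # torch.float32

# Python float / list of floats: torch.get_default_dtype() is used.
print(torch.as_tensor([0.5, 1.5]).dtype)        # torch.float32 (default)

# NumPy input: the corresponding torch dtype is used.
print(torch.as_tensor(np.array([0.5])).dtype)   # torch.float64

# The removed behavior: dtype=float forced double precision regardless of input.
print(torch.as_tensor(t32, dtype=float).dtype)  # torch.float64
```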

Reviewed By: esantorella

Differential Revision: D69121327

fbshipit-source-id: 4a5639e0a541022333a8ca76d2715d4868134b24
saitcakmak authored and facebook-github-bot committed Feb 4, 2025
1 parent a972ae1 commit c0db823
Showing 4 changed files with 9 additions and 11 deletions.
5 changes: 1 addition & 4 deletions botorch/acquisition/logei.py
@@ -18,11 +18,8 @@
from __future__ import annotations

from collections.abc import Callable
-
from copy import deepcopy
-
from functools import partial
-
from typing import TypeVar

import torch
@@ -215,7 +212,7 @@ def __init__(
tau_max=check_tau(tau_max, name="tau_max"),
fat=fat,
)
-self.register_buffer("best_f", torch.as_tensor(best_f, dtype=float))
+self.register_buffer("best_f", torch.as_tensor(best_f))
self.tau_relu = check_tau(tau_relu, name="tau_relu")

def _sample_forward(self, obj: Tensor) -> Tensor:
6 changes: 3 additions & 3 deletions botorch/acquisition/monte_carlo.py
@@ -401,7 +401,7 @@ def __init__(
constraints=constraints,
eta=eta,
)
-self.register_buffer("best_f", torch.as_tensor(best_f, dtype=float))
+self.register_buffer("best_f", torch.as_tensor(best_f))

def _sample_forward(self, obj: Tensor) -> Tensor:
r"""Evaluate qExpectedImprovement per sample on the candidate set `X`.
@@ -715,9 +715,9 @@ def __init__(
constraints=constraints,
eta=eta,
)
-best_f = torch.as_tensor(best_f, dtype=float).unsqueeze(-1) # adding batch dim
+best_f = torch.as_tensor(best_f).unsqueeze(-1) # adding batch dim
self.register_buffer("best_f", best_f)
self.register_buffer("tau", torch.as_tensor(tau, dtype=float))
self.register_buffer("tau", torch.as_tensor(tau))

def _sample_forward(self, obj: Tensor) -> Tensor:
r"""Evaluate qProbabilityOfImprovement per sample on the candidate set `X`.
2 changes: 1 addition & 1 deletion botorch_community/acquisition/rei.py
@@ -187,7 +187,7 @@ def __init__(
sample_dim = tuple(range(len(self.sample_shape) + 1))
self._sample_reduction = partial(logmeanexp, dim=sample_dim)

-self.register_buffer("best_f", torch.as_tensor(best_f, dtype=float))
+self.register_buffer("best_f", torch.as_tensor(best_f))
self.fat: bool = fat
self.tau_relu: float = check_tau(tau_relu, "tau_relu")
dim: int = X_dev.shape[1]
7 changes: 4 additions & 3 deletions test/acquisition/test_logei.py
@@ -46,7 +46,6 @@
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.low_rank import sample_cached_cholesky
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
-
from botorch.utils.transforms import standardize
from torch import Tensor

@@ -131,8 +130,10 @@ def test_q_log_expected_improvement(self):
self.assertIn(k, acqf._modules)
self.assertIn(k, log_acqf._modules)

-res = acqf(X).item()
-self.assertEqual(res, 0.0)
+res = acqf(X)
+self.assertEqual(res.dtype, dtype)
+self.assertEqual(res.device.type, self.device.type)
+self.assertEqual(res.item(), 0.0)
exp_log_res = log_acqf(X).exp().item()
# Due to the smooth approximation, the value at zero should be close to, but
# not exactly zero, and upper-bounded by the tau hyperparameter.
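As a usage-level sketch of the fix (not part of this diff; it reuses the `MockModel`/`MockPosterior` test utilities imported above, and the exact shapes are illustrative), the `best_f` buffer now follows the precision of the value passed in:

```python
import torch
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.utils.testing import MockModel, MockPosterior

# A single-precision mock posterior, mirroring the use case in #2724.
samples = torch.zeros(1, 2, 1, dtype=torch.float32)
model = MockModel(MockPosterior(samples=samples))

acqf = qExpectedImprovement(model=model, best_f=torch.tensor(0.0, dtype=torch.float32))
print(acqf.best_f.dtype)  # torch.float32 (previously hard-coded to torch.float64)
```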
