Skip to content

Commit

Permalink
ENH: Implement size_per_learner in SlurmExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Dec 20, 2024
1 parent b4d64ac commit 90a457f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
21 changes: 16 additions & 5 deletions adaptive_scheduler/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ class SlurmExecutor(AdaptiveSchedulerExecutorBase):
_sequences: dict[Callable[..., Any], list[Any]] = field(default_factory=dict)
_sequence_mapping: dict[Callable[..., Any], int] = field(default_factory=dict)
_run_manager: adaptive_scheduler.RunManager | None = None
size_per_learner: int | None = None

def __post_init__(self) -> None:
if self.folder is None:
Expand All @@ -388,11 +389,21 @@ def _to_learners(self) -> tuple[list[SequenceLearner], list[Path]]:
learners = []
fnames = []
for func, args_kwargs_list in self._sequences.items():
learner = SequenceLearner(_SerializableFunctionSplatter(func), args_kwargs_list)
learners.append(learner)
assert isinstance(self.folder, Path)
name = func.__name__ if hasattr(func, "__name__") else ""
fnames.append(self.folder / f"{name}-{uuid.uuid4().hex}.pickle")
# Chunk the sequence if size_per_learner is specified
if self.size_per_learner is not None:
chunked_args_kwargs_list = [
args_kwargs_list[i : i + self.size_per_learner]
for i in range(0, len(args_kwargs_list), self.size_per_learner)
]
else:
chunked_args_kwargs_list = [args_kwargs_list]

for chunk in chunked_args_kwargs_list:
learner = SequenceLearner(_SerializableFunctionSplatter(func), chunk)
learners.append(learner)
assert isinstance(self.folder, Path)
name = func.__name__ if hasattr(func, "__name__") else ""
fnames.append(self.folder / f"{name}-{uuid.uuid4().hex}.pickle")
return learners, fnames

def finalize(self, *, start: bool = True) -> adaptive_scheduler.RunManager:
Expand Down
2 changes: 1 addition & 1 deletion adaptive_scheduler/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _sort_key(value: tuple[float | str, str]) -> tuple[float | int, str]:

def _vec_timedelta(ts: pd.Timestamp) -> str:
now = np.datetime64(datetime.now()) # noqa: DTZ005
dt = np.timedelta64(now - ts, "s") # type: ignore[operator]
dt = np.timedelta64(now - ts, "s") # type: ignore[operator, call-overload]
return f"{dt} ago"

mapping = {
Expand Down

0 comments on commit 90a457f

Please sign in to comment.