Skip to content

Commit

Permalink
Don't proliferate the metric switching
Browse files Browse the repository at this point in the history
  • Loading branch information
triceo committed Sep 26, 2024
1 parent 0ebfa91 commit 71f1327
Show file tree
Hide file tree
Showing 12 changed files with 71 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public String getPhaseTypeString() {
// ************************************************************************
@Override
public void solve(SolverScope<Solution_> solverScope) {
var phaseScope = new ConstructionHeuristicPhaseScope<>(solverScope, phaseIndex);
var phaseScope = buildPhaseScope(solverScope, phaseIndex);
phaseStarted(phaseScope);

var solutionDescriptor = solverScope.getSolutionDescriptor();
Expand Down Expand Up @@ -94,6 +94,10 @@ public void solve(SolverScope<Solution_> solverScope) {
phaseEnded(phaseScope);
}

protected ConstructionHeuristicPhaseScope<Solution_> buildPhaseScope(SolverScope<Solution_> solverScope, int phaseIndex) {
return new ConstructionHeuristicPhaseScope<>(solverScope, phaseIndex);
}

private void doStep(ConstructionHeuristicStepScope<Solution_> stepScope) {
var step = stepScope.getStep();
step.doMoveOnly(stepScope.getScoreDirector());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public void stepEnded(ConstructionHeuristicStepScope<Solution_> stepScope) {
@Override
public void addMove(ConstructionHeuristicMoveScope<Solution_> moveScope) {
selectedMoveCount++;
moveScope.getStepScope().incrementMoveEvaluationCount(moveScope.getMove());
moveScope.getStepScope().getPhaseScope()
.addMoveEvaluationCount(moveScope.getMove(), 1L);
checkPickEarly(moveScope);
if (maxScoreMoveScope == null || moveScope.getScore().compareTo(maxScoreMoveScope.getScore()) > 0) {
maxScoreMoveScope = moveScope;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
/**
* @param <Solution_> the solution type, the class with the {@link PlanningSolution} annotation
*/
public final class ConstructionHeuristicPhaseScope<Solution_> extends AbstractPhaseScope<Solution_> {
public class ConstructionHeuristicPhaseScope<Solution_> extends AbstractPhaseScope<Solution_> {

private ConstructionHeuristicStepScope<Solution_> lastCompletedStepScope;

Expand All @@ -25,8 +25,4 @@ public void setLastCompletedStepScope(ConstructionHeuristicStepScope<Solution_>
this.lastCompletedStepScope = lastCompletedStepScope;
}

// ************************************************************************
// Calculated methods
// ************************************************************************

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ public void expandNode(ExhaustiveSearchStepScope<Solution_> stepScope) {
manualEntityMimicRecorder.setRecordedEntity(expandingNode.getEntity());

int moveIndex = 0;
ExhaustiveSearchLayer moveLayer = stepScope.getPhaseScope().getLayerList().get(expandingNode.getDepth() + 1);
var phaseScope = stepScope.getPhaseScope();
ExhaustiveSearchLayer moveLayer = phaseScope.getLayerList().get(expandingNode.getDepth() + 1);
for (Move<?> move : moveSelector) {
ExhaustiveSearchNode moveNode = new ExhaustiveSearchNode(moveLayer, expandingNode);
moveIndex++;
Expand All @@ -114,10 +115,10 @@ public void expandNode(ExhaustiveSearchStepScope<Solution_> stepScope) {
// If the original value is null and the variable allows unassigned values,
// the move to null must be done too.
doMove(stepScope, moveNode);
stepScope.incrementMoveEvaluationCount(move);
phaseScope.addMoveEvaluationCount(move, 1);
// TODO in the lowest level (and only in that level) QuitEarly can be useful
// No QuitEarly because lower layers might be promising
stepScope.getPhaseScope().getSolverScope().checkYielding();
phaseScope.getSolverScope().checkYielding();
if (termination.isPhaseTerminated(stepScope.getPhaseScope())) {
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import ai.timefold.solver.core.impl.constructionheuristic.DefaultConstructionHeuristicPhase;
import ai.timefold.solver.core.impl.constructionheuristic.scope.ConstructionHeuristicPhaseScope;
import ai.timefold.solver.core.impl.constructionheuristic.scope.ConstructionHeuristicStepScope;
import ai.timefold.solver.core.impl.phase.scope.AbstractStepScope;
import ai.timefold.solver.core.impl.solver.scope.SolverScope;

final class RuinRecreateConstructionHeuristicPhase<Solution_>
extends DefaultConstructionHeuristicPhase<Solution_>
Expand All @@ -13,6 +15,16 @@ final class RuinRecreateConstructionHeuristicPhase<Solution_>
super(builder);
}

@Override
protected void collectMetrics(AbstractStepScope<Solution_> stepScope) {
// Nested phase doesn't collect metrics.
}

@Override
protected ConstructionHeuristicPhaseScope<Solution_> buildPhaseScope(SolverScope<Solution_> solverScope, int phaseIndex) {
return new RuinRecreateConstructionHeuristicPhaseScope<>(solverScope, phaseIndex);
}

@Override
protected void processWorkingSolutionDuringStep(ConstructionHeuristicStepScope<Solution_> stepScope) {
// Ruin and Recreate CH doesn't process the working solution, it is a nested phase.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ public EntityPlacer<Solution_> getEntityPlacer() {

@Override
public DefaultConstructionHeuristicPhase<Solution_> build() {
disableMetricCollection();
return new RuinRecreateConstructionHeuristicPhase<>(this);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package ai.timefold.solver.core.impl.heuristic.selector.move.generic;

import ai.timefold.solver.core.impl.constructionheuristic.scope.ConstructionHeuristicPhaseScope;
import ai.timefold.solver.core.impl.heuristic.move.Move;
import ai.timefold.solver.core.impl.solver.scope.SolverScope;

final class RuinRecreateConstructionHeuristicPhaseScope<Solution_> extends ConstructionHeuristicPhaseScope<Solution_> {

public RuinRecreateConstructionHeuristicPhaseScope(SolverScope<Solution_> solverScope, int phaseIndex) {
super(solverScope, phaseIndex);
}

@Override
public void addChildThreadsMoveEvaluationCount(long addition) {
// Nested phase does not count moves.
}

@Override
public void addMoveEvaluationCount(Move<?> move, long count) {
// Nested phase does not count moves.
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public boolean supportsNeverEndingMoveSelector() {
@Override
public void addMove(LocalSearchMoveScope<Solution_> moveScope) {
selectedMoveCount++;
moveScope.getStepScope().incrementMoveEvaluationCount(moveScope.getMove());
moveScope.getStepScope().getPhaseScope().addMoveEvaluationCount(moveScope.getMove(), 1);
if (moveScope.getAccepted()) {
acceptedMoveCount++;
checkPickEarly(moveScope);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ public abstract class AbstractPhase<Solution_> implements Phase<Solution_> {
protected final boolean assertShadowVariablesAreNotStaleAfterStep;
protected final boolean triggerFirstInitializedSolutionEvent;

protected final boolean enableCollectMetrics;

/** Used for {@link #addPhaseLifecycleListener(PhaseLifecycleListener)}. */
protected PhaseLifecycleSupport<Solution_> phaseLifecycleSupport = new PhaseLifecycleSupport<>();

Expand All @@ -48,7 +46,6 @@ protected AbstractPhase(Builder<Solution_> builder) {
assertStepScoreFromScratch = builder.assertStepScoreFromScratch;
assertExpectedStepScore = builder.assertExpectedStepScore;
assertShadowVariablesAreNotStaleAfterStep = builder.assertShadowVariablesAreNotStaleAfterStep;
enableCollectMetrics = builder.enableCollectMetrics;
triggerFirstInitializedSolutionEvent = builder.triggerFirstInitializedSolutionEvent;
}

Expand Down Expand Up @@ -106,7 +103,6 @@ public void solvingEnded(SolverScope<Solution_> solverScope) {
@Override
public void phaseStarted(AbstractPhaseScope<Solution_> phaseScope) {
phaseScope.startingNow();
phaseScope.setEnableCollectMetrics(enableCollectMetrics);
phaseScope.reset();
solver.phaseStarted(phaseScope);
phaseTermination.phaseStarted(phaseScope);
Expand Down Expand Up @@ -160,15 +156,13 @@ protected <Score_ extends Score<Score_>> void predictWorkingStepScore(AbstractSt
@Override
public void stepEnded(AbstractStepScope<Solution_> stepScope) {
solver.stepEnded(stepScope);
if (enableCollectMetrics) {
collectMetrics(stepScope);
}
collectMetrics(stepScope);
phaseTermination.stepEnded(stepScope);
phaseLifecycleSupport.fireStepEnded(stepScope);
}

private void collectMetrics(AbstractStepScope<Solution_> stepScope) {
SolverScope<Solution_> solverScope = stepScope.getPhaseScope().getSolverScope();
protected void collectMetrics(AbstractStepScope<Solution_> stepScope) {
var solverScope = stepScope.getPhaseScope().getSolverScope();
if (solverScope.isMetricEnabled(SolverMetric.STEP_SCORE) && stepScope.getScore().isSolutionInitialized()) {
SolverMetric.registerScoreMetrics(SolverMetric.STEP_SCORE,
solverScope.getMonitoringTags(),
Expand Down Expand Up @@ -231,8 +225,6 @@ protected abstract static class Builder<Solution_> {
private boolean assertExpectedStepScore = false;
private boolean assertShadowVariablesAreNotStaleAfterStep = false;

private boolean enableCollectMetrics = true;

protected Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination) {
this(phaseIndex, false, logIndentation, phaseTermination);
}
Expand All @@ -257,10 +249,6 @@ public void setAssertShadowVariablesAreNotStaleAfterStep(boolean assertShadowVar
this.assertShadowVariablesAreNotStaleAfterStep = assertShadowVariablesAreNotStaleAfterStep;
}

public void disableMetricCollection() {
this.enableCollectMetrics = false;
}

protected abstract AbstractPhase<Solution_> build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import ai.timefold.solver.core.api.domain.solution.PlanningSolution;
import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.config.solver.monitoring.SolverMetric;
import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor;
import ai.timefold.solver.core.impl.heuristic.move.Move;
import ai.timefold.solver.core.impl.score.director.InnerScoreDirector;
import ai.timefold.solver.core.impl.solver.scope.SolverScope;

Expand Down Expand Up @@ -34,8 +36,6 @@ public abstract class AbstractPhaseScope<Solution_> {

protected int bestSolutionStepIndex;

protected boolean enableCollectMetrics = true;

/**
* As defined by #AbstractPhaseScope(SolverScope, int, boolean)
* with the phaseSendingBestSolutionEvents parameter set to true.
Expand Down Expand Up @@ -93,20 +93,6 @@ public void setBestSolutionStepIndex(int bestSolutionStepIndex) {

public abstract AbstractStepScope<Solution_> getLastCompletedStepScope();

/**
* @return true, if the metrics collection, such as
* {@link ai.timefold.solver.core.config.solver.monitoring.SolverMetric#MOVE_COUNT_PER_TYPE MOVE_COUNT_PER_TYPE},
* is enabled.
* This is disabled for nested phases, such as Construction heuristics in Ruin and Recreate.
*/
public boolean isMetricCollectionEnabled() {
return enableCollectMetrics;
}

public void setEnableCollectMetrics(boolean enableCollectMetrics) {
this.enableCollectMetrics = enableCollectMetrics;
}

// ************************************************************************
// Calculated methods
// ************************************************************************
Expand Down Expand Up @@ -159,6 +145,23 @@ public void addChildThreadsMoveEvaluationCount(long addition) {
childThreadsMoveEvaluationCount += addition;
}

public void addMoveEvaluationCount(Move<?> move, long count) {
solverScope.addMoveEvaluationCount(1);
addMoveEvaluationCountPerType(move, count);
}

public void addMoveEvaluationCountPerType(Move<?> move, long count) {
if (solverScope.isMetricEnabled(SolverMetric.MOVE_COUNT_PER_TYPE)) {
solverScope.addMoveEvaluationCountPerType(move.getSimpleMoveTypeDescription(), count);
}
}

public void addMoveEvaluationCountPerType(String moveDescription, long count) {
if (solverScope.isMetricEnabled(SolverMetric.MOVE_COUNT_PER_TYPE)) {
solverScope.addMoveEvaluationCountPerType(moveDescription, count);
}
}

public long getPhaseScoreCalculationCount() {
return endingScoreCalculationCount - startingScoreCalculationCount + childThreadsScoreCalculationCount;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import ai.timefold.solver.core.api.domain.solution.PlanningSolution;
import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.config.solver.monitoring.SolverMetric;
import ai.timefold.solver.core.impl.heuristic.move.Move;
import ai.timefold.solver.core.impl.score.director.InnerScoreDirector;

/**
Expand Down Expand Up @@ -46,15 +44,6 @@ public void setBestScoreImproved(Boolean bestScoreImproved) {
this.bestScoreImproved = bestScoreImproved;
}

public void incrementMoveEvaluationCount(Move<?> move) {
if (getPhaseScope().isMetricCollectionEnabled()) {
getPhaseScope().getSolverScope().addMoveEvaluationCount(1L);
if (getPhaseScope().getSolverScope().isMetricEnabled(SolverMetric.MOVE_COUNT_PER_TYPE)) {
getPhaseScope().getSolverScope().incrementMoveEvaluationCountPerType(move);
}
}
}

// ************************************************************************
// Calculated methods
// ************************************************************************
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import ai.timefold.solver.core.api.solver.Solver;
import ai.timefold.solver.core.config.solver.monitoring.SolverMetric;
import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor;
import ai.timefold.solver.core.impl.heuristic.move.Move;
import ai.timefold.solver.core.impl.phase.scope.AbstractPhaseScope;
import ai.timefold.solver.core.impl.score.definition.ScoreDefinition;
import ai.timefold.solver.core.impl.score.director.InnerScoreDirector;
Expand Down Expand Up @@ -376,10 +375,6 @@ public void destroyYielding() {
}
}

public void incrementMoveEvaluationCountPerType(Move<?> move) {
addMoveEvaluationCountPerType(move.getSimpleMoveTypeDescription(), 1L);
}

public void addMoveEvaluationCountPerType(String moveType, long count) {
moveEvaluationCountPerTypeMap.compute(moveType, (key, counter) -> {
if (counter == null) {
Expand Down

0 comments on commit 71f1327

Please sign in to comment.