Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make possible to rewrite discrete expressions with rules #983

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,9 @@
import org.chocosolver.solver.variables.IntVar;
import org.chocosolver.util.tools.MathUtils;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.OptionalInt;
import java.util.*;

/**
*
* arithmetic expression
* <p>
* Project: choco-solver.
Expand All @@ -36,7 +32,7 @@ public interface ArExpression {
/**
* List of available operator for arithmetic expression
*/
enum Operator {
enum Operator implements ExpOperator {
/**
* negation operator
*/
Expand Down Expand Up @@ -143,13 +139,13 @@ int eval(int i1) {

@Override
int eval(int i1, int i2) {
if(i2 == 0){
if(i1>0) {
if (i2 == 0) {
if (i1 > 0) {
return Integer.MAX_VALUE;
}else{
} else {
return Integer.MIN_VALUE;
}
}else {
} else {
return i1 / i2;
}
}
Expand All @@ -170,13 +166,13 @@ int eval(int i1) {

@Override
int eval(int i1, int i2) {
if(i2 == 0){
if(i1>0) {
if (i2 == 0) {
if (i1 > 0) {
return Integer.MAX_VALUE;
}else{
} else {
return Integer.MIN_VALUE;
}
}else {
} else {
return i1 % i2;
}
}
Expand Down Expand Up @@ -262,7 +258,7 @@ int identity() {
return Integer.MIN_VALUE;
}
},
NOP{
NOP {
@Override
int eval(int i1) {
return 0;
Expand Down Expand Up @@ -307,19 +303,20 @@ int identity() {
/**
* @return an {@link OptionalInt} which contains an {@code int} this expression is a primitive.
*/
default OptionalInt primitive(){
default OptionalInt primitive() {
return OptionalInt.empty();
}

/**
* @return <tt>true</tt> if this expression is a leaf, ie a variable, <tt>false</tt> otherwise
*/
default boolean isExpressionLeaf(){
default boolean isExpressionLeaf() {
return false;
}

/**
* Extract the variables from this expression
*
* @param variables set of variables
*/
default void extractVar(HashSet<IntVar> variables) {
Expand All @@ -334,29 +331,69 @@ default void extractVar(HashSet<IntVar> variables) {

/**
* @param values int values to evaluate
* @param map mapping between variables of the topmost expression and position in <i>values</i>
* @param map mapping between variables of the topmost expression and position in <i>values</i>
* @return an evaluation of this expression with a tuple
*/
@SuppressWarnings("SuspiciousMethodCalls")
default int ieval(int[] values, Map<IntVar, Integer> map){
default int ieval(int[] values, Map<IntVar, Integer> map) {
assert this instanceof IntVar;
return values[map.get(this)];
}

/**
* @return the child of this expression, or null if thid
*/
default int getNoChild(){
default int getNoChild() {
return 0;
}

/**
* @return the child of this expression, or null if thid
*/
default ArExpression[] getExpressionChild(){
default ArExpression[] getExpressionChild() {
return NO_CHILD;
}

/**
* @return the operator of this expression
*/
default ExpOperator getOperator() {
return Operator.NOP;
}

/**
* Replace the sub-expression at position <i>idx</i> in the expression by <i>e</i>.
*
* @param idx index of the expression to replace
* @param e the new expression
* @implSpec This method is only supposed to be used by {{@link #rewrite(List)}
*/
default void set(int idx, ArExpression e) {
}

/**
* Rewrite the current expression by applying the rewriting rules in parameters.
* @param rules list of rules to applied sequentially
* @return the rewritten expression
* @implNote The rules are applied sequentially according to the order defined by the list.
* A rule is first applied to each sub-expressions, recursively, before being applied to the top expression.
* For a given rule to be applied twice, put it twice in the list.
* If a rule rewrite an expression, the replacement is made in place ({@link #set(int, ArExpression)}.
*/
@SuppressWarnings("unchecked")
default <E extends ArExpression> E rewrite(List<Rule<ArExpression>> rules) {
E rewritten = (E) this;
for (Rule<ArExpression> rule : rules) {
ArExpression[] children = this.getExpressionChild();
for (int i = 0; i < children.length; i++) {
rewritten.set(i, children[i].rewrite(Collections.singletonList(rule)));
}
if (rule.predicate.test(rewritten)) {
rewritten = (E) rule.rewriter.apply(rewritten);
}
}
return rewritten;
}

/**
* @return return the expression "-x" where this is "x"
*/
Expand Down Expand Up @@ -384,7 +421,7 @@ default ArExpression add(int y) {
* @return return the expression "x + y" where this is "x"
*/
default ArExpression add(ArExpression y) {
if(y.primitive().isPresent()){
if (y.primitive().isPresent()) {
return add(y.primitive().getAsInt());
}
return new BiArExpression(ArExpression.Operator.ADD, this, y);
Expand Down Expand Up @@ -431,7 +468,7 @@ default ArExpression mul(int y) {
* @return return the expression "x * y" where this is "x"
*/
default ArExpression mul(ArExpression y) {
if(y.primitive().isPresent()) {
if (y.primitive().isPresent()) {
return mul(y.primitive().getAsInt());
}
return new BiArExpression(ArExpression.Operator.MUL, this, y);
Expand All @@ -458,7 +495,7 @@ default ArExpression div(int y) {
* @return return the expression "x / y" where this is "x"
*/
default ArExpression div(ArExpression y) {
if(y.primitive().isPresent()) {
if (y.primitive().isPresent()) {
return div(y.primitive().getAsInt());
}
return new BiArExpression(ArExpression.Operator.DIV, this, y);
Expand All @@ -477,7 +514,7 @@ default ArExpression mod(int y) {
* @return return the expression "x % y" where this is "x"
*/
default ArExpression mod(ArExpression y) {
if(y.primitive().isPresent()) {
if (y.primitive().isPresent()) {
return mod(y.primitive().getAsInt());
}
return new BiArExpression(ArExpression.Operator.MOD, this, y);
Expand All @@ -503,7 +540,7 @@ default ArExpression pow(int y) {
* @return return the expression "x + y" where this is "x"
*/
default ArExpression pow(ArExpression y) {
if(y.primitive().isPresent()) {
if (y.primitive().isPresent()) {
return pow(y.primitive().getAsInt());
}
return new BiArExpression(ArExpression.Operator.POW, this, y);
Expand All @@ -522,7 +559,7 @@ default ArExpression min(int y) {
* @return return the expression "min(x, y)" where this is "x"
*/
default ArExpression min(ArExpression y) {
if(y.primitive().isPresent()) {
if (y.primitive().isPresent()) {
return min(y.primitive().getAsInt());
}
return new BiArExpression(ArExpression.Operator.MIN, this, y);
Expand All @@ -549,7 +586,7 @@ default ArExpression max(int y) {
* @return return the expression "max(x, y)" where this is "x"
*/
default ArExpression max(ArExpression y) {
if(y.primitive().isPresent()) {
if (y.primitive().isPresent()) {
return max(y.primitive().getAsInt());
}
return new BiArExpression(ArExpression.Operator.MAX, this, y);
Expand Down Expand Up @@ -592,7 +629,7 @@ default ReExpression lt(int y) {
* @return return the expression "x < y" where this is "x"
*/
default ReExpression lt(ArExpression y) {
if(y.primitive().isPresent()) {
if (y.primitive().isPresent()) {
return lt(y.primitive().getAsInt());
}
return new BiReExpression(ReExpression.Operator.LT, this, y);
Expand Down Expand Up @@ -630,7 +667,7 @@ default ReExpression gt(int y) {
* @return return the expression "x > y" where this is "x"
*/
default ReExpression gt(ArExpression y) {
if(y.primitive().isPresent()) {
if (y.primitive().isPresent()) {
return gt(y.primitive().getAsInt());
}
return new BiReExpression(ReExpression.Operator.GT, this, y);
Expand Down Expand Up @@ -668,7 +705,7 @@ default ReExpression ne(int y) {
* @return return the expression "x =/= y" where this is "x"
*/
default ReExpression ne(ArExpression y) {
if(y.primitive().isPresent()) {
if (y.primitive().isPresent()) {
return ne(y.primitive().getAsInt());
}
return new BiReExpression(ReExpression.Operator.NE, this, y);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ public class BiArExpression implements ArExpression {
/**
* The first expression this expression relies on
*/
private final ArExpression e1;
private ArExpression e1;
/**
* The second expression this expression relies on
*/
private final ArExpression e2;
private ArExpression e2;

/**
* Builds a binary expression
Expand Down Expand Up @@ -110,10 +110,10 @@ public IntVar intVar() {
bounds = VariableUtils.boundsForPow(v1, v2);
me = model.intVar(model.generateName("pow_exp_"), bounds[0], bounds[1]);
Tuples tuples = new Tuples(true);
for(int val1 : v1){
for(int val2 : v2){
int res = (int)Math.pow(val1, val2);
if(me.contains(res)) {
for (int val1 : v1) {
for (int val2 : v2) {
int res = (int) Math.pow(val1, val2);
if (me.contains(res)) {
tuples.add(val1, val2, res);
}
}
Expand Down Expand Up @@ -152,6 +152,17 @@ public ArExpression[] getExpressionChild() {
return new ArExpression[]{e1, e2};
}

@Override
public ExpOperator getOperator() {
return op;
}

@Override
public void set(int idx, ArExpression e) {
if (idx == 0) this.e1 = e;
if (idx == 1) this.e2 = e;
}

@Override
public String toString() {
return op.name() + "(" + e1.toString() + "," + e2.toString() + ")";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* This file is part of choco-solver, http://choco-solver.org/
*
* Copyright (c) 2022, IMT Atlantique. All rights reserved.
*
* Licensed under the BSD 4-clause license.
*
* See LICENSE file in the project root for full license information.
*/
package org.chocosolver.solver.expression.discrete.arithmetic;

/**
* An interface that is implemented by expressions to defined specific operators.
* This interface is only required to ease rewriting expressions.
* <br/>
*
* @author Charles Prud'homme
* @since 30/11/2022
*/
public interface ExpOperator {
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,18 @@ public ArExpression[] getExpressionChild() {
return es;
}

@Override
public ExpOperator getOperator() {
return op;
}

@Override
public void set(int idx, ArExpression e) {
if (idx >= 0 && idx < es.length) {
this.es[idx] = e;
}
}

@Override
public String toString() {
return op.name() + "(" + es[0].toString() + ",... ," + es[es.length - 1].toString() + ")";
Expand Down
Loading