Skip to content

Commit

Permalink
Domain propagation for comparison operators
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirsch committed Jan 14, 2025
1 parent 33ffa01 commit b5ef7cd
Showing 1 changed file with 74 additions and 15 deletions.
89 changes: 74 additions & 15 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,22 +412,23 @@ def constrain(self, constraining_line):
results = Values(self.sid_line)
for value in self.values:
constraint_line = self.values[value]
results.set_value(Values.AND(constraining_line, constraint_line), self.sid_line, value)
results.set_value(self.sid_line, value,
Values.AND(constraining_line, constraint_line))
return results

def copy(self):
results = Values(self.sid_line)
for value in self.values:
constraint = self.values[value]
results.set_value(constraint, self.sid_line, value)
results.set_value(self.sid_line, value, constraint)
return results

def merge(self, values):
assert self.sid_line == values.sid_line
results = self.copy()
for value in values.values:
constraint = values.values[value]
results.set_value(constraint, values.sid_line, value)
results.set_value(values.sid_line, value, constraint)
return results

def get_expression(self):
Expand Down Expand Up @@ -466,7 +467,7 @@ def get_expression(self):
assert exp_line is not None
return exp_line

def set_value(self, constraint_line, sid_line, value):
def set_value(self, sid_line, value, constraint_line):
assert self.sid_line == sid_line
assert 0 <= value < 2**sid_line.size
if constraint_line != Constant.false:
Expand Down Expand Up @@ -540,7 +541,8 @@ def get_mapped_array_expression_for(self, index):

def get_values(self, step):
if 0 not in self.cache_values:
self.cache_values[0] = Values(self.sid_line).set_value(Constant.true, self.sid_line, self.value)
self.cache_values[0] = Values(self.sid_line).set_value(self.sid_line, self.value,
Constant.true)
return self.cache_values[0]

def get_z3(self):
Expand Down Expand Up @@ -916,7 +918,7 @@ def propagate(self, arg1_value, op_lambda):
results = Values(self.sid_line)
for value in arg1_value.values:
constraint_line = arg1_value.values[value]
results.set_value(constraint_line, self.sid_line, op_lambda(value))
results.set_value(self.sid_line, op_lambda(value), constraint_line)
return results

def get_values(self, step):
Expand Down Expand Up @@ -980,7 +982,8 @@ def propagate(self, arg1_value):
results = Values(self.sid_line)
for value in arg1_value.values:
constraint_line = arg1_value.values[value]
results.set_value(constraint_line, self.sid_line, (value & 2**(self.u + 1) - 1) >> self.l)
results.set_value(self.sid_line, (value & 2**(self.u + 1) - 1) >> self.l,
constraint_line)
return results

def get_values(self, step):
Expand Down Expand Up @@ -1039,13 +1042,13 @@ def propagate(self, arg1_value, op_lambda):
constraint_line = arg1_value.values[value]
if op_lambda == (lambda x: not x):
if value == 0:
results.set_value(constraint_line, self.sid_line, 1)
results.set_value(self.sid_line, 1, constraint_line)
else:
assert value == 1
results.set_value(constraint_line, self.sid_line, 0)
results.set_value(self.sid_line, 0, constraint_line)
else:
results.set_value(constraint_line, self.sid_line,
op_lambda(value) % 2**self.sid_line.size)
results.set_value(self.sid_line, op_lambda(value) % 2**self.sid_line.size,
constraint_line)
return results

def get_values(self, step):
Expand Down Expand Up @@ -1167,8 +1170,8 @@ def propagate(self, arg1_value, arg2_value):
constraint1_line = arg1_value.values[value1]
constraint2_line = arg2_value.values[value2]
assert 0 <= value1, value2 <= 1
results.set_value(Values.AND(constraint1_line, constraint2_line), self.sid_line,
value2 if value1 == 1 else 1)
results.set_value(self.sid_line, value2 if value1 == 1 else 1,
Values.AND(constraint1_line, constraint2_line))
return results

def get_z3(self):
Expand All @@ -1194,6 +1197,62 @@ def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
if not arg1_line.sid_line.match_sorts(arg2_line.sid_line):
raise model_error("compatible first and second operand sorts", line_no)

def propagate(self, arg1_value, arg2_value, op_lambda):
results = Values(self.sid_line)
for value1 in arg1_value.values:
for value2 in arg2_value.values:
constraint1_line = arg1_value.values[value1]
constraint2_line = arg2_value.values[value2]
if op_lambda(value1, value2):
results.set_value(self.sid_line, 1,
Values.AND(constraint1_line, constraint2_line))
else:
results.set_value(self.sid_line, 0,
Values.AND(constraint1_line, constraint2_line))
return results

def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
arg2_value = self.arg2_line.get_values(step)
if isinstance(arg1_value, Values) and isinstance(arg2_value, Values):
if self.op == OP_EQ:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x == y)
elif self.op == OP_NEQ:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x != y)
elif self.op == OP_SGT:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: arg1_value.sid_line.get_signed_value(x) > arg2_value.sid_line.get_signed_value(y))
elif self.op == OP_UGT:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x > y)
elif self.op == OP_SGTE:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: arg1_value.sid_line.get_signed_value(x) >= arg2_value.sid_line.get_signed_value(y))
elif self.op == OP_UGTE:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x >= y)
elif self.op == OP_SLT:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: arg1_value.sid_line.get_signed_value(x) < arg2_value.sid_line.get_signed_value(y))
elif self.op == OP_ULT:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x < y)
elif self.op == OP_SLTE:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: arg1_value.sid_line.get_signed_value(x) <= arg2_value.sid_line.get_signed_value(y))
else:
assert self.op == OP_ULTE
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x <= y)
else:
arg1_value = arg1_value.get_expression()
arg2_value = arg2_value.get_expression()
self.cache_values[step] = self.copy(arg1_value, arg2_value)
return self.cache_values[step]

def get_z3(self):
if self.z3 is None:
if self.op == OP_EQ:
Expand Down Expand Up @@ -1384,8 +1443,8 @@ def propagate(self, arg1_value, arg2_value):
for value2 in arg2_value.values:
constraint1_line = arg1_value.values[value1]
constraint2_line = arg2_value.values[value2]
results.set_value(Values.AND(constraint1_line, constraint2_line), self.sid_line,
(value1 << arg2_value.sid_line.size) + value2)
results.set_value(self.sid_line, (value1 << arg2_value.sid_line.size) + value2,
Values.AND(constraint1_line, constraint2_line))
return results

def get_z3(self):
Expand Down

0 comments on commit b5ef7cd

Please sign in to comment.