Skip to content

Commit

Permalink
Domain propagation for computation operators
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirsch committed Jan 15, 2025
1 parent a085665 commit 78da0a8
Showing 1 changed file with 54 additions and 6 deletions.
60 changes: 54 additions & 6 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,12 +1258,8 @@ def propagate(self, arg1_value, arg2_value, op_lambda):
for value2 in arg2_value.values:
constraint1_line = arg1_value.values[value1]
constraint2_line = arg2_value.values[value2]
if op_lambda(value1, value2):
results.set_value(self.sid_line, 1,
Values.AND(constraint1_line, constraint2_line))
else:
results.set_value(self.sid_line, 0,
Values.AND(constraint1_line, constraint2_line))
results.set_value(self.sid_line, 1 if op_lambda(value1, value2) else 0,
Values.AND(constraint1_line, constraint2_line))
return results

def get_values(self, step):
Expand Down Expand Up @@ -1428,6 +1424,58 @@ def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
if not arg1_line.sid_line.match_sorts(arg2_line.sid_line):
raise model_error("compatible first and second operand sorts", line_no)

def propagate(self, arg1_value, arg2_value, op_lambda):
results = Values(self.sid_line)
for value1 in arg1_value.values:
for value2 in arg2_value.values:
constraint1_line = arg1_value.values[value1]
constraint2_line = arg2_value.values[value2]
results.set_value(self.sid_line, op_lambda(value1, value2) % 2**self.sid_line.size,
Values.AND(constraint1_line, constraint2_line))
return results

def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
arg2_value = self.arg2_line.get_values(step)
if isinstance(arg1_value, Values) and isinstance(arg2_value, Values):
if self.op == OP_SLL:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x << y)
elif self.op == OP_SRL:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x >> y)
elif self.op == OP_SRA:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: arg1_value.sid_line.get_signed_value(x) >> y)
elif self.op == OP_ADD:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x + y)
elif self.op == OP_SUB:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x - y)
elif self.op == OP_MUL:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x * y)
elif self.op == OP_SDIV:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: arg1_value.sid_line.get_signed_value(x) / arg2_value.sid_line.get_signed_value(y))
elif self.op == OP_UDIV:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x / y)
elif self.op == OP_SREM:
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: arg1_value.sid_line.get_signed_value(x) % arg2_value.sid_line.get_signed_value(y))
else:
assert self.op == OP_UREM
self.cache_values[step] = self.propagate(arg1_value, arg2_value,
lambda x, y: x % y)
else:
arg1_value = arg1_value.get_expression()
arg2_value = arg2_value.get_expression()
self.cache_values[step] = self.copy(arg1_value, arg2_value)
return self.cache_values[step]

def get_z3(self):
if self.z3 is None:
if self.op == OP_SLL:
Expand Down

0 comments on commit 78da0a8

Please sign in to comment.