Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AST] Add math operators to simplify and test cases for them #501

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion heterocl/ast/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,41 @@ def simplify(expr):
index = struct.index
e = struct.tensor.fcompute(*index)[expr.field]
return sp.simplify(simplify(e))
if isinstance(expr, SelectOp): # pylint: disable=no-else-return
if isinstance(expr, SelectOp):
if simplify(expr.cond):
return sp.simplify(simplify(expr.true_value))
return sp.simplify(simplify(expr.false_value))
if isinstance(expr, MathExpOp):
expr = unwrap_sp(simplify(expr.expr))
return sp.exp(expr)
if isinstance(expr, MathPowOp):
lhs = unwrap_sp(simplify(expr.lhs))
rhs = unwrap_sp(simplify(expr.rhs))
return sp.Pow(lhs, rhs)
if isinstance(expr, MathLogOp):
expr = unwrap_sp(simplify(expr.expr))
return sp.log(expr)
if isinstance(expr, MathLog2Op):
expr = unwrap_sp(simplify(expr.expr))
return sp.log(expr, 2)
if isinstance(expr, MathLog10Op):
expr = unwrap_sp(simplify(expr.expr))
return sp.log(expr, 10)
if isinstance(expr, MathSqrtOp):
expr = unwrap_sp(simplify(expr.expr))
return sp.sqrt(expr)
if isinstance(expr, MathSinOp):
expr = unwrap_sp(simplify(expr.expr))
return sp.sin(expr)
if isinstance(expr, MathCosOp):
expr = unwrap_sp(simplify(expr.expr))
return sp.cos(expr)
# if isinstance(expr, MathTanOp):
# expr = unwrap_sp(simplify(expr.expr))
# return sp.tan(expr)
if isinstance(expr, MathTanhOp): # pylint: disable=no-else-return
expr = unwrap_sp(simplify(expr.expr))
return sp.tanh(expr)
else:
raise HCLError(f"Unsupported expression type: {type(expr)}")

Expand Down
264 changes: 264 additions & 0 deletions tests/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,267 @@ def kernel(A):
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b10101, 0b01110]


def test_math_exp():
def kernel(A):
lower_idx = hcl.exp(0)
lower_idx = hcl.cast(hcl.Index(), lower_idx)
upper_idx = hcl.exp(1.5)
upper_idx = hcl.cast(hcl.Index(), upper_idx) # e^1.5 ~4.482 -> 4

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10101100, 0b01100101])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b110, 0b010]


def test_math_pow():
def kernel(A):
a = hcl.scalar(0)
b = hcl.scalar(2)

lower_idx = hcl.power(b.v, a.v)
lower_idx = hcl.cast(hcl.Index(), lower_idx)
upper_idx = hcl.power(b.v, b.v)
upper_idx = hcl.cast(hcl.Index(), upper_idx)

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10101100, 0b01100101])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b110, 0b010]


def test_math_log_op():
def kernel(A):
a = hcl.scalar(1)

lower_idx = hcl.log(a.v)
lower_idx = hcl.cast(hcl.Index(), lower_idx)
upper_idx = hcl.log(12)
upper_idx = hcl.cast(hcl.Index(), upper_idx) + 3

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b01001111, 0b11101110])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b01111, 0b01110]


def test_math_log2_op():
def kernel(A):
a = hcl.scalar(16)

lower_idx = hcl.log2(2)
lower_idx = hcl.cast(hcl.Index(), lower_idx)
upper_idx = hcl.log2(a.v)
upper_idx = hcl.cast(hcl.Index(), upper_idx)

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b00001010, 0b11101110])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b101, 0b111]


def test_math_log10_op():
def kernel(A):
a = hcl.scalar(10)
b = hcl.scalar(10000)

lower_idx = hcl.log10(a.v)
lower_idx = hcl.cast(hcl.Index(), lower_idx)
upper_idx = hcl.log10(b.v)
upper_idx = hcl.cast(hcl.Index(), upper_idx)

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b01111100, 0b01001010])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b110, 0b101]


def test_math_sqrt1():
def kernel(A):
a = hcl.scalar(25)

upper_idx = hcl.sqrt(a.v)
upper_idx = hcl.cast(hcl.Index(), upper_idx)

B = hcl.compute(A.shape, lambda x: A[x][0:upper_idx])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10011101, 0b01001010])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b11101, 0b01010]


def test_math_sqrt2():
def kernel(A):
upper_idx = hcl.sqrt(30)
upper_idx = hcl.cast(hcl.Index(), upper_idx) # sqrt(30) ~ 5.477 -> 5

B = hcl.compute(A.shape, lambda x: A[x][0:upper_idx])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10011101, 0b01001010])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b11101, 0b01010]


def test_math_sin1():
def kernel(A):
lower_idx = hcl.sin(0)
lower_idx = hcl.cast(hcl.Index(), lower_idx)

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:3])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10011101, 0b10110001])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b101, 0b001]


def test_math_sin2():
def kernel(A):
lower_idx = hcl.sin(np.pi / 2)
lower_idx = hcl.cast(hcl.Index(), lower_idx)

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:5])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10011101, 0b10110001])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b1110, 0b1000]


def test_math_cos1():
def kernel(A):
lower_idx = hcl.cos(0)
lower_idx = hcl.cast(hcl.Index(), lower_idx)

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10011101, 0b10110001])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b110, 0b000]


def test_math_cos2():
def kernel(A):
lower_idx = hcl.cos(np.pi / 2)
lower_idx = hcl.cast(hcl.Index(), lower_idx)

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10011101, 0b10110001])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b1101, 0b0001]


# hcl.tan() not supported yet

# def test_math_tan():
# def kernel(A):
# lower_idx = hcl.tan(0)
# lower_idx = hcl.cast(hcl.Index(), lower_idx)

# B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4])
# return B

# A = hcl.placeholder((2,), "A")
# s = hcl.create_schedule([A], kernel)
# f = hcl.build(s)
# np_A = hcl.asarray([0b10011101, 0b00101100])
# np_B = hcl.asarray([0, 0])
# f(np_A, np_B)
# assert np_B.asnumpy().tolist() == [0b1101, 0b1100]


def test_math_tanh1():
def kernel(A):
lower_idx = hcl.tanh(0)
lower_idx = hcl.cast(hcl.Index(), lower_idx)

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10101010, 0b11001011])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b1010, 0b1011]


def test_math_tanh2():
def kernel(A):
a = hcl.scalar(8)

lower_idx = hcl.tanh(a.v)
lower_idx = hcl.cast(hcl.Index(), lower_idx)

B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4])
return B

A = hcl.placeholder((2,), "A")
s = hcl.create_schedule([A], kernel)
f = hcl.build(s)
np_A = hcl.asarray([0b10101011, 0b11001011])
np_B = hcl.asarray([0, 0])
f(np_A, np_B)
assert np_B.asnumpy().tolist() == [0b101, 0b101]