"""Expression AST Node in TVM.

Users do not need to deal with expression AST nodes directly,
but they can be helpful for developers doing quick prototyping.
Although they are not displayed in the documentation or in this Python file,
each expression node has subfields that can be visited from the Python side.
For example, you can use addexp.a to get the left operand of an Add node.

.. code-block:: python

    x = tvm.var("n")
    y = x + 2
    assert(isinstance(y, tvm.expr.Add))
    assert(y.a == x)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node
from . import make as _make
from . import _api_internal
from ..debug import APIError
def _integer_intrinsic(operand, intrinsic, action, args):
    """Build a ``Call.PureIntrinsic`` call typed like ``operand``.

    Shared by the shift and bitwise operators of ``ExprOp`` so that every
    one of them applies the same float rejection before building the call.

    Parameters
    ----------
    operand : ExprOp
        Expression whose ``dtype`` is checked and reused as the call type.

    intrinsic : str
        Intrinsic name passed to ``_make.Call``.

    action : str
        Name of the operation as it should appear in the error message.

    args : list
        Arguments of the intrinsic call.

    Returns
    -------
    ret : Expr
        The result of ``_make.Call``.

    Raises
    ------
    APIError
        If ``"float"`` occurs in ``operand.dtype``.
    """
    if "float" in operand.dtype:
        raise APIError("Cannot perform " + action + " with float")
    return _make.Call(operand.dtype, intrinsic, args, Call.PureIntrinsic, None, 0)


class ExprOp(object):
    """Operator overloads that build expression nodes through ``_make``.

    Note that ``==`` and ``!=`` do not build a node immediately: they return
    the deferred ``EqualOp`` / ``NotEqualOp`` wrappers. Use :meth:`equal` to
    get an ``EQ`` node directly.

    Because ``__eq__`` is defined here without ``__hash__``, Python 3 makes
    this class unhashable; subclasses that need hashing must restore it.
    """

    # ---- arithmetic -----------------------------------------------------
    def __add__(self, other):
        return _make.Add(self, other)

    def __radd__(self, other):
        # Reuses __add__, so the node is built as (self, other).
        return self.__add__(other)

    def __sub__(self, other):
        return _make.Sub(self, other)

    def __rsub__(self, other):
        return _make.Sub(other, self)

    def __mul__(self, other):
        return _make.Mul(self, other)

    def __rmul__(self, other):
        return _make.Mul(other, self)

    # All division spellings (Python 2 ``/``, true division and floor
    # division) are routed to the same ``Div`` node.
    def __div__(self, other):
        return _make.Div(self, other)

    def __rdiv__(self, other):
        return _make.Div(other, self)

    def __truediv__(self, other):
        return self.__div__(other)

    def __rtruediv__(self, other):
        return self.__rdiv__(other)

    def __floordiv__(self, other):
        return self.__div__(other)

    def __rfloordiv__(self, other):
        return self.__rdiv__(other)

    def __mod__(self, other):
        return _make.Mod(self, other)

    def __neg__(self):
        # Negation is expressed as multiplication by a -1 constant of the
        # same dtype.
        neg_one = _api_internal._const(-1, self.dtype)
        return self.__mul__(neg_one)

    # ---- shifts and bitwise operators (rejected for float dtypes) --------
    def __lshift__(self, other):
        return _integer_intrinsic(self, "shift_left", "shift", [self, other])

    def __rshift__(self, other):
        return _integer_intrinsic(self, "shift_right", "shift", [self, other])

    def __and__(self, other):
        return _integer_intrinsic(self, "bitwise_and", "bitwise and", [self, other])

    def __or__(self, other):
        return _integer_intrinsic(self, "bitwise_or", "bitwise or", [self, other])

    def __xor__(self, other):
        return _integer_intrinsic(self, "bitwise_xor", "bitwise xor", [self, other])

    def __invert__(self):
        # Goes through the same float guard as the other bitwise operators.
        return _integer_intrinsic(self, "bitwise_not", "bitwise not", [self])

    # ---- comparisons ----------------------------------------------------
    def __lt__(self, other):
        return _make.LT(self, other)

    def __le__(self, other):
        return _make.LE(self, other)

    def __eq__(self, other):
        return EqualOp(self, other)

    def __ne__(self, other):
        return NotEqualOp(self, other)

    def __gt__(self, other):
        return _make.GT(self, other)

    def __ge__(self, other):
        return _make.GE(self, other)

    # ---- bit / slice access ---------------------------------------------
    def __getitem__(self, indices):
        if isinstance(indices, slice):
            # Only start and stop are forwarded; the slice step is not used.
            return _make.GetSlice(self, indices.start, indices.stop)
        return _make.GetBit(self, indices)

    def __setitem__(self, indices, expr):
        raise APIError("Cannot set bit/slice of an expression")

    # ---- truth value ----------------------------------------------------
    def __nonzero__(self):
        raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
                         "use tvm.all / tvm.any instead")

    def __bool__(self):
        return self.__nonzero__()

    def equal(self, other):
        """Build an equal check expression with other expr.

        Parameters
        ----------
        other : Expr
            The other expression

        Returns
        -------
        ret : Expr
            The equality expression.
        """
        return _make.EQ(self, other)

    def astype(self, dtype):
        """Cast the expression to other type.

        Parameters
        ----------
        dtype : str
            The type of new expression

        Returns
        -------
        expr : Expr
            Expression with new type
        """
        return _make.static_cast(dtype, self)
class EqualOp(NodeGeneric, ExprOp):
    """Deferred equal operator.

    This is used to support sugar that a == b can either
    mean NodeBase.same_as or NodeBase.equal: taking the truth value of
    this object calls ``a.same_as(b)``, while :meth:`asnode` builds the
    ``EQ`` node.

    Parameters
    ----------
    a : Expr
        Left operand.

    b : Expr
        Right operand.
    """

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def same_as(self, other):
        """Return True only when ``other`` is this very object.

        This class is not manipulated by C++, so Python's identity check is
        sufficient. An explicit ``is`` is used instead of ``object.__eq__``
        because the latter returns ``NotImplemented`` (not ``False``) for
        distinct objects, which is not a valid truth-value result.
        """
        return self is other

    def __nonzero__(self):
        return self.a.same_as(self.b)

    def __bool__(self):
        return self.__nonzero__()

    def asnode(self):
        """Convert node."""
        return _make.EQ(self.a, self.b)
class NotEqualOp(NodeGeneric, ExprOp):
    """Deferred NE operator.

    This is used to support sugar that a != b can either
    mean not NodeBase.same_as or make.NE: taking the truth value of this
    object evaluates ``not a.same_as(b)``, while :meth:`asnode` builds the
    ``NE`` node.

    Parameters
    ----------
    a : Expr
        Left operand.

    b : Expr
        Right operand.
    """

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def same_as(self, other):
        """Return True only when ``other`` is this very object.

        This class is not manipulated by C++, so Python's identity check is
        sufficient. An explicit ``is`` is used instead of ``object.__eq__``
        because the latter returns ``NotImplemented`` (not ``False``) for
        distinct objects.
        """
        return self is other

    def __nonzero__(self):
        return not self.a.same_as(self.b)

    def __bool__(self):
        return self.__nonzero__()

    def asnode(self):
        """Convert node."""
        return _make.NE(self.a, self.b)
class Expr(ExprOp, NodeBase):
    """Base class of all tvm Expressions"""

    # ExprOp overrides __eq__, and Python 3 drops the inherited __hash__
    # whenever __eq__ is overridden, so NodeBase's hash is re-attached
    # explicitly to keep expressions hashable.
    # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
    __hash__ = NodeBase.__hash__
class ConstExpr(Expr):
    """Common base of the immediate nodes (FloatImm, IntImm, UIntImm, StringImm)."""


class BinaryOpExpr(Expr):
    """Common base of the binary nodes (Add, Sub, Mul, Div, Mod, Min, Max)."""


class CmpExpr(Expr):
    """Common base of the comparison nodes (EQ, NE, LT, LE, GT, GE)."""


class LogicalExpr(Expr):
    """Common base of the logical nodes (And, Or, Not)."""
@register_node("Variable")
class Var(Expr):
    """Symbolic variable."""


@register_node
class Reduce(Expr):
    """Python-side class for the ``Reduce`` node."""


@register_node
class FloatImm(ConstExpr):
    """Python-side class for the ``FloatImm`` constant node."""


@register_node
class IntImm(ConstExpr):
    """Python-side class for the ``IntImm`` constant node."""


@register_node
class UIntImm(ConstExpr):
    """Python-side class for the ``UIntImm`` constant node."""


@register_node
class StringImm(ConstExpr):
    """Python-side class for the ``StringImm`` constant node."""


@register_node
class Cast(Expr):
    """Python-side class for the ``Cast`` node."""
@register_node
class Add(BinaryOpExpr):
    """``Add`` node; ``ExprOp.__add__`` builds it through ``_make.Add``."""


@register_node
class Sub(BinaryOpExpr):
    """``Sub`` node; ``ExprOp.__sub__`` builds it through ``_make.Sub``."""


@register_node
class Mul(BinaryOpExpr):
    """``Mul`` node; ``ExprOp.__mul__`` builds it through ``_make.Mul``."""


@register_node
class Div(BinaryOpExpr):
    """``Div`` node; ``ExprOp.__div__`` builds it through ``_make.Div``."""


@register_node
class Mod(BinaryOpExpr):
    """``Mod`` node; ``ExprOp.__mod__`` builds it through ``_make.Mod``."""


@register_node
class Min(BinaryOpExpr):
    """Python-side class for the ``Min`` binary node."""


@register_node
class Max(BinaryOpExpr):
    """Python-side class for the ``Max`` binary node."""
@register_node
class EQ(CmpExpr):
    """``EQ`` node; built by ``ExprOp.equal`` and ``EqualOp.asnode``."""


@register_node
class NE(CmpExpr):
    """``NE`` node; built by ``NotEqualOp.asnode``."""


@register_node
class LT(CmpExpr):
    """``LT`` node; built by ``ExprOp.__lt__``."""


@register_node
class LE(CmpExpr):
    """``LE`` node; built by ``ExprOp.__le__``."""


@register_node
class GT(CmpExpr):
    """``GT`` node; built by ``ExprOp.__gt__``."""


@register_node
class GE(CmpExpr):
    """``GE`` node; built by ``ExprOp.__ge__``."""
@register_node
class And(LogicalExpr):
    """Python-side class for the ``And`` logical node."""


@register_node
class Or(LogicalExpr):
    """Python-side class for the ``Or`` logical node."""


@register_node
class Not(LogicalExpr):
    """Python-side class for the ``Not`` logical node."""
@register_node
class Select(Expr):
    """Python-side class for the ``Select`` node."""


@register_node
class Load(Expr):
    """Python-side class for the ``Load`` node."""


@register_node
class Ramp(Expr):
    """Python-side class for the ``Ramp`` node."""


@register_node
class Broadcast(Expr):
    """Python-side class for the ``Broadcast`` node."""


@register_node
class Shuffle(Expr):
    """Python-side class for the ``Shuffle`` node."""
@register_node
class Call(Expr):
    """Python-side class for the ``Call`` node.

    The class attributes are integer call-type tags. ``PureIntrinsic`` is the
    tag that the shift and bitwise operators of ``ExprOp`` pass to
    ``_make.Call``.
    """

    # Call-type tags; the numeric values must stay as they are.
    Extern = 0
    ExternCPlusPlus = 1
    PureExtern = 2
    Halide = 3
    Intrinsic = 4
    PureIntrinsic = 5
@register_node
class Let(Expr):
    """Python-side class for the ``Let`` node."""


@register_node
class GetBit(Expr):
    """``GetBit`` node; ``ExprOp.__getitem__`` builds it for a non-slice index."""


@register_node
class GetSlice(Expr):
    """``GetSlice`` node; ``ExprOp.__getitem__`` builds it for a slice index."""


@register_node
class SetBit(Expr):
    """Python-side class for the ``SetBit`` node."""


@register_node
class SetSlice(Expr):
    """Python-side class for the ``SetSlice`` node."""


@register_node
class Quantize(Expr):
    """Python-side class for the ``Quantize`` node."""


@register_node
class KernelExpr(Expr):
    """Python-side class for the ``KernelExpr`` node."""
@register_node
class StreamExpr(Expr):
    """Python-side class for the ``StreamExpr`` node.

    The class attributes are integer tags (presumably the kind of stream;
    confirm against the code that consumes them).
    """

    # Integer tags; the numeric values must stay as they are.
    FIFO = 0
    DoubleBuffer = 1