"""HeteroCL imperative DSL."""
#pylint: disable=too-many-arguments,missing-docstring
from .tvm import make as _make
from .tvm import stmt as _stmt
from .tvm import ir_pass as _pass
from .tvm._api_internal import _IterVar, _Var
from .tvm.ir_builder import WithScope
from .api import placeholder
from .debug import DSLError, APIError
from .schedule import Stage
from .module import Module
from . import util
[docs]def and_(*args):
"""Compute the logic AND between expressions.
If there is only one argument, itself is returned.
Parameters
----------
args : list of Expr
A list of expression to be computed
Returns
-------
Expr
Examples
--------
.. code-block:: python
A = hcl.placeholder((3,))
cond = hcl.and_(A[0] > 0, A[1] > 1, A[2] > 2)
"""
ret = args[0]
for i in range(1, len(args)):
ret = _make.And(ret, args[i])
return ret
[docs]def or_(*args):
"""Compute the logic OR between expressions.
If there is only one argument, itself is returned.
Parameters
----------
args : list of Expr
A list of expression to be computed
Returns
-------
Expr
Examples
--------
.. code-block:: python
A = hcl.placeholder((3,))
cond = hcl.or_(A[0] > 0, A[1] > 1, A[2] > 2)
"""
ret = args[0]
for i in range(1, len(args)):
ret = _make.Or(ret, args[i])
return ret
[docs]def if_(cond):
"""Construct an IF branch.
The usage is the same as Python `if` statement. Namely, a single `if`
statement without the `else` branch is allowed. In addition, we cannot
use `else` and `elif` without an `if` statement. Finally, an `else`
statement must be preceded by either an `if` or `elif` statement.
Parameters
----------
cond : Expr
The condition of the `if` statement
Returns
-------
None
Examples
--------
.. code-block:: python
def my_compute(x):
with hcl.if_(A[x] < 3):
# do something
with hcl.elif_(A[x] < 6):
# do something
with hcl.else_():
# do something
"""
if not Stage.get_len():
raise DSLError("Imperative DSL must be used with other compute APIs")
stage = Stage.get_current()
stage.stmt_stack.append([])
def _exit_cb():
stmt = stage.pop_stmt()
stage.has_break = False
stage.emit(_make.IfThenElse(cond, stmt, None))
return WithScope(None, _exit_cb)
[docs]def else_():
"""Construct an ELSE branch.
Parameters
----------
Returns
-------
None
See Also
--------
if_
"""
if not Stage.get_len():
raise DSLError("Imperative DSL must be used with other compute APIs")
stage = Stage.get_current()
prev = stage.stmt_stack[-1][-1]
if not isinstance(prev, _stmt.IfThenElse):
raise DSLError("There is no if_ or elif_ in front of the else_ branch")
stage.stmt_stack[-1].pop()
stage.stmt_stack.append([])
def _exit_cb():
stmt = stage.pop_stmt()
stage.has_break = False
stage.emit(stage.replace_else(prev, stmt))
return WithScope(None, _exit_cb)
[docs]def elif_(cond):
"""Construct an ELIF branch.
Parameters
----------
cond : Expr
The condition of the branch
Returns
-------
None
See Also
--------
if_
"""
if not Stage.get_len():
raise DSLError("Imperative DSL must be used with other compute APIs")
stage = Stage.get_current()
prev = stage.stmt_stack[-1][-1]
if not isinstance(prev, _stmt.IfThenElse):
raise DSLError("There is no if_ or elif_ in front of the elif_ branch")
stage.stmt_stack[-1].pop()
stage.stmt_stack.append([])
def _exit_cb():
stmt = stage.pop_stmt()
stage.has_break = False
stage.emit(stage.replace_else(prev, _make.IfThenElse(cond, stmt, None)))
return WithScope(None, _exit_cb)
[docs]def for_(begin, end, step=1, name="i", dtype="int32", for_type="serial"):
"""Construct a FOR loop.
Create an imperative for loop based on the given bound and step. It is
the same as the following Python code.
.. code-block:: python
for i in range(begin, end, step):
# do something
The bound and step can be negative values. In addition, `begin` is
inclusive while `end` is exclusive.
Parameters
----------
begin : Expr
The starting bound of the loop
end : Expr
The ending bound of the loop
step : Expr, optional
The step of the loop
name : str, optional
The name of the iteration variable
dtype : Type, optional
The data type of the iteration variable
for_type : str, optional
The type of the for loop
Returns
-------
Var
The iteration variable
See Also
--------
break_
Examples
--------
.. code-block:: python
# example 1 - basic usage
with hcl.for_(0, 5) as i:
# i = [0, 1, 2, 3, 4]
# example 2 - negative step
with hcl.for_(5, 0, -1) as i:
# i = [5, 4, 3, 2, 1]
# example 3 - larger step
with hcl.for_(0, 5, 2) as i:
# i = [0, 2, 4]
# example 4 - arbitrary bound
with hcl.for_(-4, -8, -2) as i:
# i = [-4, -6]
"""
if not Stage.get_len():
raise DSLError("Imperative DSL must be used with other compute APIs")
stage = Stage.get_current()
stage.stmt_stack.append([])
extent = (end - begin) // step
extent = util.CastRemover().mutate(extent)
name = "i"+str(stage.for_ID) if name is None else name
stage.for_ID += 1
iter_var = _IterVar(_make.range_by_min_extent(0, extent), _Var(name, dtype), 0, '')
stage.var_dict[name] = iter_var
stage.axis_list.append(iter_var)
stage.for_level += 1
def _exit_cb():
if for_type == "serial":
for_type_id = 0
elif for_type == "parallel":
for_type_id = 1
elif for_type == "vectorize":
for_type_id = 2
elif for_type == "unroll":
for_type_id = 3
else:
raise ValueError("Unknown for_type")
stmt = _make.AttrStmt(iter_var, "loop_scope", iter_var.var, stage.pop_stmt())
stage.has_break = False
stage.for_level -= 1
stage.emit(_make.For(iter_var.var, 0, extent, for_type_id, 0, stmt))
ret_var = _pass.Simplify(iter_var.var * step + begin)
return WithScope(ret_var, _exit_cb)
[docs]def while_(cond):
"""Construct a WHILE loop.
Parameters
----------
cond : Expr
The condition of the loop
Returns
-------
None
See Also
--------
break_
Examples
--------
.. code-block:: python
with hcl.while_(A[x] > 5):
# do something
"""
if not Stage.get_len():
raise DSLError("Imperative DSL must be used with other compute APIs")
stage = Stage.get_current()
stage.stmt_stack.append([])
stage.for_level += 1
def _exit_cb():
stmt = stage.pop_stmt()
stage.has_break = False
stage.for_level -= 1
stage.emit(_make.While(cond, stmt))
return WithScope(None, _exit_cb)
[docs]def break_():
"""
Construct a BREAK statement.
This DSL can only be used inside a `while` loop or a `for loop`. Moreover,
it is not allowed to have tracing statements after the `break`.
Parameters
----------
Returns
-------
None
Examples
--------
.. code-block:: python
# example 1 - inside a for loop
with hcl.for_(0, 5) as i:
with hcl.if_(A[i] > 5):
hcl.break_()
# example 2 - inside a while loop
with hcl.while_(A[i] > 5):
with hcl.if_(A[i] > 10):
hcl.break_()
"""
if not Stage.get_len():
raise DSLError("Imperative DSL must be used with other compute APIs")
if not Stage.get_current().for_level:
raise DSLError("break_ must be used inside a for/while loop")
Stage.get_current().emit(_make.Break())
Stage.get_current().has_break = True
[docs]def def_(shapes, dtypes=None, ret_dtype=None, name=None, arg_names=None):
"""
Define a HeteroCL function from a Python function.
This DSL is used as a Python decorator. The function defined with HeteroCL
is not inlined by default. Users need to provide the shapes of the
arguments, while the data types of the arguments and the returned data
type are optional. This DSL helps make the algorithm more organized and
could potentially reduce the memory usage by reusing the same
functionality. Users can later on use compute primitives to decide whether
to inline these functions or not.
After specifying a Python function to be a HeteroCL function, users can
use the function just like using a Python function. We can also apply
optimization primitives.
Parameters
----------
shapes : list of tuple
The shapes of the arguments
dtypes : list of Type, optional
The data types of the argument
ret_dtype : Type, optional
The data type of the returned value
name : str, optional
The name of the function. By default, it is the same as the Python
function
Returns
-------
None
Examples
--------
.. code-block:: python
# example 1 - no return
A = hcl.placeholder((10,))
B = hcl.placeholder((10,))
x = hcl.placeholder(())
@hcl.def_([A.shape, B.shape, x.shape])
def update_B(A, B, x):
with hcl.for_(0, 10) as i:
B[i] = A[i] + x
# directly call the function
update_B(A, B, x)
# example 2 - with return value
@hcl.def_([(10,), (10,), ()])
def ret_add(A, B, x):
hcl.return_(A[x] + B[x])
# use inside a compute API
A = hcl.placeholder((10,))
B = hcl.placeholder((10,))
C = hcl.compute((10,), lambda x: ret_add(A, B, x))
D = hcl.compute((10,), lambda x: ret_add(A, C, x))
"""
def decorator(fmodule, shapes=shapes, dtypes=dtypes, ret_dtype=ret_dtype, name=name, arg_names=arg_names):
name = name if name is not None else fmodule.__name__
code = fmodule.__code__
names = code.co_varnames
if arg_names is not None:
names = list(names)
for i in range(len(arg_names)):
names[i] = arg_names[i]
names = tuple(names)
nargs = code.co_argcount
with Stage(name) as s:
# prepare names
new_names = [s.name_with_prefix + "." + name_ for name_ in names]
# prepare dtypes
hcl_dtypes = []
if dtypes is None:
dtypes = []
for name_ in new_names:
dtypes.append(util.get_tvm_dtype(None, name_))
hcl_dtypes.append(util.get_dtype(None, name_))
elif isinstance(dtypes, list):
if len(dtypes) != nargs:
raise APIError("The number of data types does not match the of arguments")
for (name_, dtype_) in zip(new_names, dtypes):
dtypes.append(util.get_tvm_dtype(dtype_, name_))
hcl_dtypes.append(util.get_dtype(dtype_, name_))
dtypes = dtypes[int(len(dtypes)/2):]
else:
dtype = util.get_tvm_dtype(dtypes)
dtypes = []
for name_ in new_names:
dtypes.append(util.get_tvm_dtype(dtype, name_))
ret_dtype = util.get_tvm_dtype(ret_dtype, s.name_with_prefix)
# prepare inputs for IR generation
inputs = []
inputs_tvm = []
arg_shapes, arg_dtypes, arg_tensors = [], [], []
for shape, name_, dtype, htype in zip(shapes, new_names, dtypes, hcl_dtypes):
if shape == ():
var_ = placeholder((), name_, dtype)
inputs.append(var_)
inputs_tvm.append(var_.var)
arg_shapes.append([1])
arg_dtypes.append(dtype)
else: # tensor inputs (new bufs)
placeholder_ = placeholder(shape, name_, htype)
inputs.append(placeholder_)
inputs_tvm.append(placeholder_.buf.data)
arg_shapes.append(list(shape))
arg_dtypes.append(dtype)
arg_tensors.append(placeholder_.op)
s.ret_dtype = ret_dtype
s._module = True
s._inputs = inputs
fmodule(*inputs)
lhs = []
for tensor in s.lhs_tensors:
try:
lhs.append(inputs.index(tensor))
except ValueError:
pass
ret_void = _make.UIntImm("uint1", 0) if s.has_return else _make.UIntImm("uint1", 1)
body = s.pop_stmt()
s.stmt_stack.append([])
s.emit(_make.KernelDef(inputs_tvm, arg_shapes, arg_dtypes, arg_tensors,
body, ret_void, ret_dtype, name, []))
for name_, i in zip(names, inputs):
s.var_dict[name_] = i
s.input_stages.clear()
return Module(shapes, names, name, not s.has_return, lhs, ret_dtype)
return decorator
[docs]def return_(val):
"""Return an expression within a function.
This DSL should be used within a function definition. The return type can
only be an expression.
Parameters
----------
val : Expr
The returned expression
Returns
-------
None
See Also
--------
heterocl.compute, def_
Examples
--------
.. code-block:: python
# example 1 - using with a compute API
A = hcl.placeholder((10,))
def compute_out(x):
with hcl.if_(A[x]>0):
hcl.return_(1)
with hcl.else_():
hcl.return_(0)
B = hcl.compute(A.shape, compute_out)
# example 2 - using with a HeteroCL function
A = hcl.placeholder((10,))
@hcl.def_([A.shape, ()])
def compute_out(A, x):
with hcl.if_(A[x]>0):
hcl.return_(1)
with hcl.else_():
hcl.return_(0)
B = hcl.compute(A.shape, lambda x: compute_out(A, x))
"""
if not Stage.get_len():
raise DSLError("Imperative DSL must be used with other compute APIs")
stage = Stage.get_current()
dtype = util.get_tvm_dtype(stage.ret_dtype)
stage.emit(_make.Return(_make.Cast(dtype, val)))
stage.has_return = True