# Source code for heterocl.tvm.api

"""Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin
from __future__ import absolute_import as _abs

from numbers import Integral as _Integral

from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal
from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from . import container as _container
from . import tag as _tag

# Canonical dtype strings; ``float32`` is the default dtype used by
# ``decl_buffer`` in this module.
int8, int32, float32, handle = "int8", "int32", "float32", "handle"


def min_value(dtype):
    """Return the minimum value of ``dtype``.

    Parameters
    ----------
    dtype : str
        The data type to query.

    Returns
    -------
    value : object
        Whatever the FFI call ``_api_internal._min_value`` returns;
        presumably a constant expression of ``dtype`` -- confirm.
    """
    lookup = _api_internal._min_value
    return lookup(dtype)
def max_value(dtype):
    """Return the maximum value of ``dtype``.

    Parameters
    ----------
    dtype : str
        The data type to query.

    Returns
    -------
    value : object
        Whatever the FFI call ``_api_internal._max_value`` returns;
        presumably a constant expression of ``dtype`` -- confirm.
    """
    lookup = _api_internal._max_value
    return lookup(dtype)
def const(value, dtype=None):
    """Construct a constant.

    Parameters
    ----------
    value : number
        The constant value.

    dtype : str, optional
        Explicit data type. When omitted, any ``numbers.Integral`` value
        (which includes ``bool``) defaults to ``'int32'``; every other value
        defaults to ``'float32'``.

    Returns
    -------
    node : object
        The result of the FFI call ``_api_internal._const``.
    """
    if dtype is None:
        dtype = 'int32' if isinstance(value, _Integral) else 'float32'
    return _api_internal._const(value, dtype)
def convert(value):
    """Convert value to TVM node or function.

    Values that are already a ``Function`` or ``NodeBase`` are returned
    unchanged. Other callables are wrapped as TVM functions, and everything
    else goes through the generic node conversion.

    Parameters
    ----------
    value : python value

    Returns
    -------
    tvm_val : Node or Function
        Converted value in TVM
    """
    already_converted = isinstance(value, (Function, NodeBase))
    if already_converted:
        return value
    converter = _convert_tvm_func if callable(value) else _convert_to_node
    return converter(value)
def load_json(json_str):
    """Load tvm object from json_str.

    Parameters
    ----------
    json_str : str
        The json string

    Returns
    -------
    node : Node
        The loaded tvm node.
    """
    node = _api_internal._load_json(json_str)
    return node
def save_json(node):
    """Save tvm object as json string.

    This is the counterpart of ``load_json``.

    Parameters
    ----------
    node : Node
        A TVM Node object to be saved.

    Returns
    -------
    json_str : str
        Saved json string.
    """
    return _api_internal._save_json(node)
def any(*args):
    """Create a new expression of the union of all conditions in the arguments.

    The conditions are combined left to right with ``_make.Or``, so
    ``any(a, b, c)`` builds ``Or(Or(a, b), c)``. A single argument is
    returned unchanged.

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    Returns
    -------
    expr: Expr
        Expression

    Raises
    ------
    ValueError
        If called with no arguments.
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    ret = args[0]
    # Left fold over the remaining conditions; same construction order as
    # Or(args[0], args[1]) followed by Or(ret, args[i]) for i >= 2.
    for cond in args[1:]:
        ret = _make.Or(ret, cond)
    return ret
def all(*args):
    """Create a new expression of the intersection of all conditions in the arguments.

    The conditions are combined left to right with ``_make.And``, so
    ``all(a, b, c)`` builds ``And(And(a, b), c)``. A single argument is
    returned unchanged.

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    Returns
    -------
    expr: Expr
        Expression

    Raises
    ------
    ValueError
        If called with no arguments.
    """
    if not args:
        # The message names this function ("All"), not its sibling "any".
        raise ValueError("All must take at least 1 argument")
    ret = args[0]
    # Left fold over the remaining conditions; same construction order as
    # And(args[0], args[1]) followed by And(ret, args[i]) for i >= 2.
    for cond in args[1:]:
        ret = _make.And(ret, cond)
    return ret
def decl_buffer(shape,
                dtype=None,
                name="buffer",
                data=None,
                strides=None,
                elem_offset=None,
                scope="",
                data_alignment=-1,
                offset_factor=0):
    """Declare a new symbolic buffer.

    Normally buffer is created automatically during lower and build.
    This is only needed if user want to specify their own buffer layout.

    See the note below for detailed discussion on usage of buffer.

    Parameters
    ----------
    shape : tuple of Expr
        The shape of the buffer. A single Expr or integer is treated as a
        one-dimensional shape.

    dtype : str, optional
        The data type of the buffer. Defaults to ``"float32"``.

    name : str, optional
        The name of the buffer.

    data : Var, optional
        The data pointer in the buffer. When omitted, a new ``"handle"``
        variable named ``name`` is created.

    strides: array of Expr
        The stride of the buffer. ``None`` is treated as an empty tuple.

    elem_offset: Expr, optional
        The beginning offset of the array to data.
        In terms of number of elements of dtype.

    scope: str, optional
        The storage scope of the buffer, if not global.
        If scope equals empty string, it means it is global memory.

    data_alignment: int, optional
        The alignment of data pointer in bytes.
        If -1 is passed, the alignment will be set to TVM's internal default.

    offset_factor: int, optional
        The factor of elem_offset field, when set,
        elem_offset is required to be multiple of offset_factor.
        If 0 is passed, the alignment will be set to 1.
        If non-zero is passed and elem_offset is None, a Var named
        ``<name>_elem_offset`` is created for elem_offset. Its dtype is taken
        from ``shape[0]`` when that has a ``dtype`` attribute, otherwise
        ``"int32"`` is used.

    Returns
    -------
    buffer : Buffer
        The created buffer

    Note
    ----
    Buffer data structure reflects the DLTensor structure in dlpack.
    While DLTensor data structure is very general, it is usually helpful
    to create function that only handles specific case of data structure
    and make compiled function benefit from it.

    If user pass strides and elem_offset is passed as None
    when constructing the function, then the function will be specialized
    for the DLTensor that is compact and aligned.
    If user pass a fully generic symbolic array to the strides,
    then the resulting function becomes fully generic.
    """
    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
    dtype = float32 if dtype is None else dtype
    strides = () if strides is None else strides
    if offset_factor != 0 and elem_offset is None:
        # shape[0] may be a plain Python int (e.g. shape=(16, 16)), which has
        # no ``dtype`` attribute; fall back to "int32" instead of raising
        # AttributeError.
        shape_dtype = getattr(shape[0], "dtype", "int32")
        elem_offset = _api_internal._Var('%s_elem_offset' % name, shape_dtype)
    if data is None:
        data = _api_internal._Var(name, "handle")
    return _api_internal._Buffer(
        data, dtype, shape, strides, elem_offset, name, scope,
        data_alignment, offset_factor)
def _IterVar(dom, name, iter_type, thread_tag=''):
    """Internal function to create IterVar

    Parameters
    ----------
    dom : Range, two-element list/tuple, or None
        The domain of iteration. A ``[begin, end]`` list or tuple is turned
        into a Range via ``Range(begin, end)``.

    name : str
        The name of iteration variable; an empty/falsy name becomes ``'iter'``.

    iter_type : int
        The type of iteration.

    thread_tag : str
        The thread tag of the iteration variable.

    Returns
    -------
    iter_var : IterVar
        The result itervar

    Raises
    ------
    TypeError
        If ``dom`` is a list/tuple whose length is not 2, or if a non-None
        ``dom`` does not end up as a ``_container.Range``.
    """
    has_dom = dom is not None
    # isinstance(None, ...) is False, so this branch only runs when has_dom.
    if isinstance(dom, (list, tuple)):
        if len(dom) != 2:
            raise TypeError("need to be list of ranges")
        begin, end = dom
        # NOTE(review): `Range` is neither imported nor defined in this file;
        # presumably it is injected into module globals by the
        # `_init_api("tvm.api")` call at import time -- confirm.
        dom = Range(begin, end)
    if has_dom and not isinstance(dom, _container.Range):
        raise TypeError("dom need to be Range")
    loop_var = _api_internal._Var(name or 'iter', "int32")
    return _api_internal._IterVar(dom, loop_var, iter_type, thread_tag)
def select(cond, t, f):
    """Construct a select branch.

    Each operand is passed through ``convert`` (in the order cond, t, f)
    before being handed to ``_make.Select``.

    Parameters
    ----------
    cond : Expr
        The condition
    t : Expr
        The result expression if cond is true.
    f : Expr
        The result expression if cond is false.

    Returns
    -------
    node : Node
        The tvm.expr.Select node
    """
    operands = tuple(convert(operand) for operand in (cond, t, f))
    return _make.Select(*operands)
# Module-level FFI initialisation for the "tvm.api" namespace.
# NOTE(review): presumably this binds the FFI-registered global functions
# onto this module (and possibly onto _api_internal); the exact behaviour is
# defined by _init_api in ._ffi.function, which is not visible here -- confirm.
_init_api("tvm.api")