# Source code for heterocl.tvm.tag
"""Tag class for TVM operators."""
from ._ffi.base import _LIB_NAME
try:
from decorator import decorate
except ImportError as err_msg:
# Allow decorator to be missing in runtime
if _LIB_NAME != "libhcl_runtime.so":
raise err_msg
class TagScope(object):
    """Scope that marks the operator tag currently in effect.

    A ``TagScope`` works both as a context manager and as a decorator.
    While it is entered, ``TagScope.current`` points at it; on exit the
    previous value (always ``None``, since nesting is rejected) is restored.
    See also ``tag_scope``.
    """

    # The scope that is currently entered, or None when none is active.
    # This is a class attribute, so it is shared by every instance.
    current = None

    def __init__(self, tag):
        self.tag = tag
        # Value of TagScope.current captured on entry; restored on exit.
        self._old_scope = None

    def __enter__(self):
        previous = TagScope.current
        # Only a single active scope is supported at the moment.
        if previous is not None:
            raise ValueError("nested op_tag is not allowed for now")
        self._old_scope = previous
        TagScope.current = self
        return self

    def __exit__(self, ptype, value, trace):
        # __enter__ refuses to nest, so nothing else can have been saved.
        assert self._old_scope is None
        TagScope.current = self._old_scope

    def __call__(self, fdecl):
        scope = self

        def tagged_fdecl(func, *args, **kwargs):
            # Run the wrapped function with this scope active.
            with scope:
                return func(*args, **kwargs)

        # ``decorate`` comes from the optional ``decorator`` package
        # imported at module level.
        return decorate(fdecl, tagged_fdecl)
def tag_scope(tag):
    """The operator tag scope.

    Creates a new ``TagScope`` for ``tag`` on every call. Entering the
    returned scope raises ``ValueError`` if another ``TagScope`` is already
    active, so scopes cannot be nested.

    Parameters
    ----------
    tag: str
        The tag name.

    Returns
    -------
    tag_scope: TagScope
        The tag scope object, which can be used as decorator or
        context manager.

    Example
    -------
    .. code-block:: python

        n = tvm.var('n')
        m = tvm.var('m')
        l = tvm.var('l')
        A = tvm.placeholder((n, l), name='A')
        B = tvm.placeholder((m, l), name='B')
        k = tvm.reduce_axis((0, l), name='k')

        with tvm.tag_scope(tag='matmul'):
            C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))

        # or use tag_scope as decorator
        @tvm.tag_scope(tag="relu")
        def compute_relu(data):
            return tvm.compute(data.shape, lambda *i: tvm.select(data(*i) < 0, 0.0, data(*i)))
    """
    return TagScope(tag)