"""HeteroCL tensors and scalars."""
#pylint: disable=missing-docstring, too-many-instance-attributes
from .tvm import make as _make
from .tvm import expr as _expr
from .tvm.api import decl_buffer
from .tvm._ffi.node import NodeGeneric
from .debug import TensorError
from .schedule import Stage
from . import util
from . import debug
from . import types
class Scalar(NodeGeneric, _expr.ExprOp):
    """An immutable scalar backed by a TVM variable.

    Instances are meant to be created by `heterocl.placeholder` only. A
    scalar can be used directly in expressions, and individual bits or bit
    ranges can be read through indexing.

    Parameters
    ----------
    var : Var
        The TVM variable to wrap

    Attributes
    ----------
    var : Var
        The wrapped TVM variable

    name : str
        The name of the wrapped variable

    dtype : Type
        The data type of the wrapped variable

    See Also
    --------
    heterocl.placeholder

    Examples
    --------
    .. code-block:: python

        # an empty shape requests a non-mutable scalar
        a = hcl.placeholder((), "a")
        # use the scalar directly
        b = a + 5
        # read bits
        c = a[2]   # the third bit of a
        d = a[3:5] # a bit slice of a
    """

    def __init__(self, var):
        self.var = var

    def __getitem__(self, indices):
        """Read a single bit or a bit slice of the scalar.

        A `slice` index builds a `GetSlice` node from its start and stop; an
        `int` or expression index builds a `GetBit` node. Any other index
        type raises `TensorError`.
        """
        if isinstance(indices, slice):
            first, second = indices.start, indices.stop
            return _make.GetSlice(self.var, first, second)
        if isinstance(indices, (int, _expr.Expr)):
            return _make.GetBit(self.var, indices)
        raise TensorError("Invalid index")

    @property
    def name(self):
        """Name of the wrapped variable."""
        return self.var.name

    @property
    def dtype(self):
        """Data type of the wrapped variable."""
        return self.var.dtype

    def same_as(self, var):
        """Check whether `var` refers to the same underlying variable.

        `var` may be another `Scalar` (its wrapped variable is compared) or
        an expression (compared directly). Anything else yields False.
        """
        if isinstance(var, Scalar):
            other = var.var
        elif isinstance(var, _expr.Expr):
            other = var
        else:
            return False
        return self.var.same_as(other)

    def asnode(self):
        """Return the wrapped TVM variable as the node form of this scalar."""
        return self.var
class TensorSlice(NodeGeneric, _expr.ExprOp):
    """A helper class for tensor operations.

    Valid tensor accesses include: 1. getting an element from a tensor 2. bit
    operations on the element. We **do not** support operations on a slice of
    tensor.

    When the tensor's HeteroCL type is a struct, attribute access on the
    slice (``A[i].field``) is translated into a bit-range access that covers
    that field.

    Parameters
    ----------
    tensor : Tensor
        The target tensor

    indices : int or tuple of int
        The indices to access the tensor

    dtype : str, optional
        The data type of the accessed value. Defaults to the tensor's dtype.

    Attributes
    ----------
    tensor : Tensor
        The target tensor

    indices : int or tuple of int
        The indices to access the tensor

    dtype : Type
        The data type of the tensor

    Examples
    --------
    .. code-block:: python

        A = hcl.placeholder((10,), "A")
        # get a single element
        a = A[5]
        # bit operations on a single element
        b = A[5][2]
        c = A[5][3:7]
        # not allowed: A[5:7]
    """

    # Names stored directly on the instance. Every other attribute name that
    # reaches __getattr__/__setattr__ is interpreted as a struct field.
    _INSTANCE_ATTRS = ("tensor", "indices", "_dtype")

    def __init__(self, tensor, indices, dtype=None):
        if not isinstance(indices, tuple):
            indices = (indices,)
        self.tensor = tensor
        self.indices = indices
        self._dtype = dtype if dtype is not None else self.tensor.dtype

    def __getitem__(self, indices):
        """Extend the access with more indices and return a new slice.

        The new slice is built without a dtype argument, so it uses the
        tensor's dtype.
        """
        if not isinstance(indices, tuple):
            indices = (indices,)
        return TensorSlice(self.tensor, self.indices + indices)

    def __setitem__(self, indices, expr):
        """Emit a store of `expr` to the addressed element, bit, or bit range.

        Raises
        ------
        TensorError
            If there is no active stage (`Stage.get_len()` is zero).
        """
        if not isinstance(indices, tuple):
            indices = (indices,)
        indices = self.indices + indices
        index, bit, _ = util.get_index(self.tensor.shape, indices, 0)
        if not Stage.get_len():
            raise TensorError("Cannot set tensor elements without compute APIs")
        builder = Stage.get_current()
        if bit is None:
            # plain element store
            builder.emit(_make.Store(self.tensor.buf.data,
                                     _make.Cast(self._dtype, expr),
                                     index))
        elif isinstance(bit, slice):
            # read-modify-write of a bit range
            load = _make.Load(self.tensor.dtype, self.tensor.buf.data, index)
            # special handle for struct: we need to make sure the bitwidths
            # are the same before and after bitcast
            # NOTE(review): util.get_type() is indexed with [0]/[1] elsewhere
            # in this class, so comparing its whole result to "uint" looks
            # always-true; `[0] != "uint"` may have been intended. Left as is
            # because changing it would alter the emitted IR - confirm.
            if (isinstance(self.tensor.type, types.Struct)
                    and util.get_type(self._dtype) != "uint"):
                ty = "uint" + str(util.get_type(self._dtype)[1])
                expr = _make.Call(ty, "bitcast",
                                  [expr], _expr.Call.PureIntrinsic, None, 0)
            expr = _make.SetSlice(load, expr, bit.start, bit.stop)
            builder.emit(_make.Store(self.tensor.buf.data,
                                     _make.Cast(self.tensor.dtype, expr),
                                     index))
        else:
            # read-modify-write of a single bit
            load = _make.Load(self.tensor.dtype, self.tensor.buf.data, index)
            expr = _make.SetBit(load, expr, bit)
            builder.emit(_make.Store(self.tensor.buf.data,
                                     _make.Cast(self._dtype, expr),
                                     index))

    def _find_field(self, key):
        """Return ``(bit_range, dtype_str)`` for struct field `key`.

        Walks the struct's fields in `dtype_dict` order, accumulating bit
        widths, and returns the range as ``slice(end, start)``.

        Raises
        ------
        TensorError
            If the tensor's HeteroCL type is not a struct.
        DTypeError
            If the struct has no field named `key`.
        """
        hcl_dtype = self.tensor.hcl_dtype
        if not isinstance(hcl_dtype, types.Struct):
            raise TensorError(
                "Cannot access attribute if type is not struct")
        start = 0
        for dkey, dval in hcl_dtype.dtype_dict.items():
            if dkey == key:
                end = start + dval.bits
                return slice(end, start), types.dtype_to_str(dval)
            start += dval.bits
        # The bare name DTypeError is never imported in this file, so it is
        # reached through the imported `debug` module (assumes debug defines
        # it next to TensorError/APIError - TODO confirm).
        raise debug.DTypeError("Field " + key
                               + " is not in struct " + str(hcl_dtype))

    def __getattr__(self, key):
        """Read struct field `key` as a new slice typed with the field dtype."""
        # __getattr__ only runs after normal lookup fails. If one of the
        # instance's own attributes is missing, reading self.tensor below
        # would re-enter this method without bound, so stop here instead.
        if key in self._INSTANCE_ATTRS:
            raise AttributeError(key)
        bit_range, dtype = self._find_field(key)
        return TensorSlice(self.tensor, self.indices + (bit_range,), dtype)

    def __setattr__(self, key, expr):
        """Store instance attributes normally; treat others as struct fields.

        For a struct field, the slice's `_dtype` is switched to the field's
        dtype and `expr` is written to the field's bit range.
        """
        if key in self._INSTANCE_ATTRS:
            super().__setattr__(key, expr)
            return
        # An unknown field raises inside _find_field. The previous
        # `start == end` test did not catch unknown fields for structs of
        # nonzero width, so the write silently covered the whole struct.
        bit_range, dtype = self._find_field(key)
        self._dtype = dtype
        self.__setitem__((bit_range,), expr)

    @property
    def dtype(self):
        """The data type of the underlying tensor."""
        return self.tensor.dtype

    def asnode(self):
        """Build the expression that reads the addressed value.

        Raises
        ------
        TensorError
            If fewer indices than tensor dimensions were given.
        """
        if len(self.indices) < len(self.tensor.shape):
            raise TensorError("Accessing a slice of tensor is not allowed")
        index, bit, _ = util.get_index(self.tensor.shape, self.indices, 0)
        if bit is None:
            return _make.Load(self._dtype, self.tensor.buf.data, index)
        elif isinstance(bit, slice):
            load = _make.GetSlice(_make.Load(self.tensor.dtype,
                                             self.tensor.buf.data, index),
                                  bit.start,
                                  bit.stop)
            if self.tensor.dtype != self._dtype:
                # The slice carries its own dtype (set for struct fields):
                # resize to the target bitwidth first, then reinterpret.
                bw_from = types.get_bitwidth(self.tensor.dtype)
                bw_to = types.get_bitwidth(self._dtype)
                if bw_from != bw_to:
                    ty = util.get_type(self.tensor.dtype)[0] + str(bw_to)
                    load = _make.Cast(ty, load)
                return _make.Call(self._dtype, "bitcast",
                                  [load], _expr.Call.PureIntrinsic, None, 0)
            else:
                return load
        return _make.GetBit(_make.Load(self._dtype,
                                       self.tensor.buf.data,
                                       index), bit)
class Tensor(NodeGeneric, _expr.ExprOp):
    """A HeteroCL tensor.

    This is a wrapper for a TVM tensor. It should be generated from HeteroCL
    compute APIs.

    Parameters
    ----------
    shape : tuple of int
        The shape of the tensor

    dtype : Type, optional
        The data type of the tensor

    name : str, optional
        The name of the tensor

    buf : Buffer, optional
        The TVM buffer of the tensor. A new buffer is declared from `shape`,
        the dtype string and `name` when omitted.

    Attributes
    ----------
    dtype : Type
        The data type of the tensor

    name : str
        The name of the tensor

    var_dict : dict(str, Var)
        A dictionary that maps between a name and a variable

    first_update : Stage
        The first stage that updates the tensor

    last_update : Stage
        The last stage that updates the tensor

    tensor : Operation
        The TVM tensor

    buf : Buffer
        The TVM buffer

    type : Type
        The data type in HeteroCL format

    op : Stmt
        The operation statement

    axis : list of IterVar
        A list of axes of the tensor

    v : Expr
        Syntactic sugar to access the element of an single-element tensor

    See Also
    --------
    heterocl.placeholder, heterocl.compute
    """

    __hash__ = NodeGeneric.__hash__

    # Class-level map from TVM tensor to buffer, written by the `buf` setter.
    # It was not defined anywhere in this class, so assigning `buf` depended
    # on some other module creating it first. NOTE(review): if another module
    # already assigns Tensor.tensor_map, that assignment still takes effect.
    tensor_map = {}

    def __init__(self, shape, dtype="int32", name="tensor", buf=None):
        self._tensor = None
        self._buf = buf
        self.hcl_dtype = dtype
        self.dtype = types.dtype_to_str(dtype)
        self.shape = shape
        self.name = name
        self.var_dict = {}
        self.first_update = None
        self.last_update = None
        if buf is None:
            self._buf = decl_buffer(shape, self.dtype, name)

    def __repr__(self):
        return ("Tensor('" + self.name + "', " + str(self.shape)
                + ", " + str(self.dtype) + ")")

    def __getitem__(self, indices):
        """Return a `TensorSlice` addressing `indices`.

        When a stage is active, this tensor's last updating stage is
        recorded as an input of the current stage.
        """
        indices = util.CastRemover().mutate(indices)
        if Stage.get_len():
            Stage.get_current().input_stages.add(self.last_update)
        if not isinstance(indices, tuple):
            indices = (indices,)
        return TensorSlice(self, indices)

    def __setitem__(self, indices, expr):
        """Emit a store of `expr` to an element, bit, or bit range.

        Raises
        ------
        TensorError
            If there is no active stage, or if fewer indices than tensor
            dimensions are given.
        """
        # Check for an active stage before anything touches
        # Stage.get_current(); previously this check came only after the
        # current stage had already been used.
        if not Stage.get_len():
            raise TensorError("Cannot set tensor elements without compute APIs")
        indices = util.CastRemover().mutate(indices)
        builder = Stage.get_current()
        builder.input_stages.add(self.last_update)
        builder.lhs_tensors.add(self)
        if not isinstance(indices, tuple):
            indices = (indices,)
        # NOTE(review): second CastRemover pass kept from the original; it
        # looks redundant with the one above - confirm before removing.
        indices = util.CastRemover().mutate(indices)
        if len(indices) < len(self.shape):
            raise TensorError("Accessing a slice of tensor is not allowed")
        index, bit, _ = util.get_index(self.shape, indices, 0)
        # Inside Tensor, `self.tensor` is the TVM tensor (None until set),
        # not this wrapper, so the bit/slice paths must use this object's own
        # dtype and buffer just like the plain-store path does.
        if bit is None:
            value = expr
        elif isinstance(bit, slice):
            load = _make.Load(self.dtype, self.buf.data, index)
            value = _make.SetSlice(load, expr, bit.start, bit.stop)
        else:
            load = _make.Load(self.dtype, self.buf.data, index)
            value = _make.SetBit(load, expr, bit)
        builder.emit(_make.Store(self.buf.data,
                                 _make.Cast(self.dtype, value),
                                 index))

    @property
    def tensor(self):
        return self._tensor

    @tensor.setter
    def tensor(self, tensor):
        """Set the TVM tensor.

        Parameters
        ----------
        tensor : Tensor
        """
        self._tensor = tensor

    @property
    def buf(self):
        return self._buf

    @buf.setter
    def buf(self, buf):
        """Set the TVM buffer.

        The buffer is also recorded in `Tensor.tensor_map`, keyed by the
        current TVM tensor.

        Parameters
        ----------
        buf : Buffer
        """
        self._buf = buf
        Tensor.tensor_map[self._tensor] = buf

    @property
    def type(self):
        return self.hcl_dtype

    @property
    def op(self):
        return self.tensor.op

    @property
    def axis(self):
        return self.tensor.op.axis

    @property
    def v(self):
        if len(self.shape) == 1 and self.shape[0] == 1:
            return self.__getitem__(0)
        else:
            raise debug.APIError(".v can only be used on mutable scalars")

    @v.setter
    def v(self, value):
        """A syntactic sugar for setting the value of a single-element tensor.

        This is the same as using `a[0]=value`, where a is a single-element
        tensor.

        Parameters
        ----------
        value : Expr
            The value to be set
        """
        self.__setitem__(0, value)

    def asnode(self):
        """Return the node form of a single-element tensor.

        Raises
        ------
        ValueError
            If the tensor does not have shape ``(1,)``.
        """
        if len(self.shape) == 1 and self.shape[0] == 1:
            return TensorSlice(self, 0).asnode()
        else:
            raise ValueError("Cannot perform expression on Tensor")

    def same_as(self, tensor):
        """Check whether `tensor` wraps the same TVM tensor as this one.

        Returns False when `tensor` is not a `Tensor`.
        """
        if isinstance(tensor, Tensor):
            return self._tensor.same_as(tensor.tensor)
        return False