# Source code for heterocl.tvm.tensor

"""Tensor and Operation classes for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node
from . import _api_internal
from . import make as _make
from . import expr as _expr

# Module-level slot initialised to None; it is not referenced anywhere in this
# section of the file.
# NOTE(review): presumably assigned elsewhere (e.g. to the IterVar class, going
# by the name) — confirm against the rest of the package before relying on it.
itervar_cls = None

@register_node("Tensor")
class _Tensor(NodeBase, _expr.ExprOp):
    """Symbolic tensor node; build instances through ``function.Tensor``."""

    @property
    def shape(self):
        """Output shape of this tensor, read from the underlying node."""
        return self.__getattr__("shape")

    @property
    def ndim(self):
        """Number of dimensions, i.e. the length of :attr:`shape`."""
        dims = self.shape
        return len(dims)

    @property
    def op(self):
        """The :any:`Operation` that produces this tensor."""
        return self.__getattr__("op")

    @property
    def value_index(self):
        """Index of the producing operation's output that this tensor is."""
        return self.__getattr__("value_index")

    @property
    def name(self):
        """Human-readable name of the tensor.

        When the producing operation has exactly one output, the tensor
        simply borrows the operation's name. Otherwise the name is the
        operation's name followed by ``.v<value_index>``.
        """
        source_op = self.op
        if source_op.num_outputs != 1:
            return "%s.v%d" % (source_op.name, self.value_index)
        return source_op.name

    @property
    def axis(self):
        """Axis field stored on the underlying tensor node."""
        return self.__getattr__("axis")


class Operation(NodeBase):
    """Represent an operation that generates one or more tensors."""

    def output(self, index):
        """Get the index-th output of the operation.

        Parameters
        ----------
        index : int
            Position of the requested output.

        Returns
        -------
        out : Tensor
            The tensor produced at output position ``index``.
        """
        get_output = _api_internal._OpGetOutput
        return get_output(self, index)

    @property
    def num_outputs(self):
        """How many outputs this operation produces."""
        return _api_internal._OpNumOutputs(self)

    @property
    def input_tensors(self):
        """Input tensors consumed by this operation (a list)."""
        return _api_internal._OpInputTensors(self)
@register_node
class PlaceholderOp(Operation):
    """Placeholder operation; adds no members beyond :class:`Operation`."""
@register_node
class ComputeOp(Operation):
    """Compute operation."""

    @property
    def axis(self):
        """IterVar axes of this compute operation (ComputeOp only)."""
        field_name = "axis"
        return self.__getattr__(field_name)

    @property
    def reduce_axis(self):
        """IterVar axes used for reductions (ComputeOp only)."""
        field_name = "reduce_axis"
        return self.__getattr__(field_name)
@register_node
class ExternOp(Operation):
    """Extern operation; adds no members beyond :class:`Operation`."""