Source code for heterocl.schedule

"""A module for compute scheduling."""
#pylint: disable=too-many-instance-attributes, no-self-use, missing-docstring
import networkx as nx
import matplotlib.pyplot as plt
from ordered_set import OrderedSet
from .tvm import tensor
from .tvm import make as _make
from .tvm import stmt as _stmt
from .tvm import expr as _expr
from .tvm import api as tvm_api
from .tvm import _api_internal
from .tvm._api_internal import _ExternOp
from .debug import DSLError, APIError
from . import util
from .devices import Device, DevMediaPair 

class Schedule(object):
    """Create a compute schedule.

    This is a wrapper class for :obj:`tvm.schedule._Schedule`.

    Parameters
    ----------
    sch : tvm.schedule._Schedule
        The TVM schedule

    inputs : list of Tensor
        Tensors that are the inputs to the schedule
    """

    stage_ops = []
    last_stages = OrderedSet([])

    def __init__(self, sch, inputs):
        self.sch = sch
        self.inputs = inputs
        self.placement = dict()

    def __getitem__(self, stage):
        try:
            return self.sch[stage._op]
        except AttributeError:
            return self.sch[stage.op]

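    # Usage sketch (not part of this module): a Schedule is normally obtained
    # from ``hcl.create_schedule`` rather than constructed directly, and
    # ``s[stage]`` returns the underlying TVM stage. The kernel and tensor
    # names below are illustrative assumptions.
    #
    #   import heterocl as hcl
    #   hcl.init()
    #   A = hcl.placeholder((10, 10), "A")
    #   def kernel(A):
    #       return hcl.compute(A.shape, lambda x, y: A[x, y] + 1, "B")
    #   s = hcl.create_schedule([A], kernel)    # returns a Schedule
    #   s[kernel.B].unroll(kernel.B.axis[1])    # index into the TVM schedule
    #   f = hcl.build(s)
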
    def dataflow_graph(self, stages=None, level=0, plot=False):
        """Create a dataflow graph for a given schedule.

        Parameters
        ----------
        stages : list of Stage, optional
            The final stages in the graph. If not specified, draw all the
            stages

        level : int, optional
            The level of stages to draw. If not specified, draw to the
            inner-most stages

        plot : bool, optional
            Whether to draw the graph with ``matplotlib`` or not

        Returns
        -------
        networkx.DiGraph, dict
            A directed graph that describes the dataflow, along with a map
            from prefixed stage names to the corresponding TVM stages
        """
        graph = nx.DiGraph()
        level_count = [0]
        op_map = dict()
        pos = {}

        def gen_graph(stage, y):
            names = []
            for input_stage in stage.input_stages:
                if len(level_count) == y:
                    level_count.append(0)
                names += gen_graph(input_stage, y+1)
            name_with_prefix = stage.name_with_prefix
            op_map[name_with_prefix] = self.sch[stage._op]
            if len(name_with_prefix.split('.')) <= level or level == 0:
                for name in names:
                    graph.add_edge(name, name_with_prefix)
                    pos[name] = (level_count[y], y)
                    level_count[y] += 1
                return [name_with_prefix]
            return names

        if stages is None:
            stages = Schedule.last_stages
        else:
            if not isinstance(stages, (tuple, list)):
                stages = [stages]

        x = 0
        for stage in stages:
            gen_graph(stage, 1)
            pos[stage.name_with_prefix] = (x, 0)
            x += 1

        return graph, op_map

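    # Usage sketch (illustrative names only): inspect the dataflow between
    # stages of an existing schedule ``s``.
    #
    #   graph, op_map = s.dataflow_graph()
    #   print(list(graph.edges))        # e.g. [("_top.A", "_top.B"), ...]
    #   tvm_stage = op_map["_top.B"]    # prefixed stage name -> TVM stage
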
    def subgraph(self, inputs, outputs):
        """Return the subgraph (a list of stage names) between inputs and outputs."""
        assert len(inputs) > 0, "empty inputs"
        assert len(outputs) > 0, "empty outputs"
        graph, op_map = self.dataflow_graph()

        # check availability
        inputs = [_.name.replace(".new", "") for _ in inputs]
        outputs = [_.name.replace(".new", "") for _ in outputs]

        # from root to parents
        stack = outputs
        subgraph = list()
        while len(stack) > 0:
            op = stack.pop()
            if op in subgraph:
                continue
            subgraph.append(op)
            if op not in graph.nodes:
                op = "_top." + op
                assert op in graph.nodes, \
                    "cannot find node " + op + " in " + str(graph.nodes)
            for _ in graph.predecessors(op):
                if op not in inputs:
                    stack.append(_)
        return subgraph

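    # Usage sketch (illustrative only): walk the dataflow graph backwards
    # from the given output tensors until the given input tensors are
    # reached, collecting the stage names in between.
    #
    #   nodes = s.subgraph(inputs=[A], outputs=[kernel.B])
    #   # e.g. ["B", ...] -- names with any ".new" suffix stripped
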
    def reuse_at(self, target, parent, axis, name=None):
        """Create a reuse buffer reusing the output of current stage

        This returns a new tensor representing the reuse buffer. A stage
        is also built correspondingly. The new stage will be a sub-stage of
        the parent stage under the specified axis. Thus, the axis must be
        inside the axis list of the parent stage.

        Parameters
        ----------
        target : Tensor
            The tensor whose values will be reused

        parent : Stage
            The stage that reuses the output of the current stage

        axis : IterVar
            The axis that generates the reuse values

        name : string, optional
            The name of the reuse buffer

        Returns
        -------
        Tensor
        """
        try:
            target = target.tensor
        except AttributeError:
            try:
                target = target._op
            except AttributeError:
                pass

        if name is None:
            name = target.name + ".reuse"
        return self.sch.reuse_at(target, parent, axis, name)

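    # Usage sketch, following the HeteroCL reuse-buffer tutorials; the blur
    # kernel below is an illustrative assumption, not part of this module.
    #
    #   A = hcl.placeholder((10, 10), "A")
    #   def blur(A):
    #       return hcl.compute((10, 8),
    #                          lambda y, x: A[y, x] + A[y, x+1] + A[y, x+2],
    #                          "B")
    #   s = hcl.create_schedule([A], blur)
    #   # reuse values of A along B's x axis with a window buffer
    #   RB = s.reuse_at(A, s[blur.B], blur.B.axis[1], "A_reused")
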
    def join(self, srcs, dest=None):
        """Join multiple tensors to a single dest."""
        assert len(srcs) > 0, "joined tensors should be " + \
            "collected from more than one src"
        # create channels and collector stage
        if dest is not None:
            if isinstance(dest, tuple):
                dest, target = dest
                dest = self[dest]
            elif isinstance(dest, Stage):
                target = dest._op
            elif isinstance(dest, tuple):
                src, target = dest
            else:  # target tensor
                target = dest.tensor
        else:
            target = dest

        for src in srcs:
            if isinstance(src, tuple):
                src, tensor = src
                assert tensor == target, \
                    "inconsistent tensor joining"
            self.sch.join(target, dest, self[src])

    def fork(self, tensor, dests, axis=0):
        """Fork a tensor to multiple dests."""
        assert len(dests) > 0, "forked tensor should be " + \
            "broadcast to more than one dest"
        # dest as tvm stages
        for dest in dests:
            self.to(tensor, self[dest])

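    # Usage sketch (hypothetical stage names): broadcast one tensor to
    # several consumer stages; each destination is forwarded through ``to``.
    #
    #   s.fork(A, [kernel.B, kernel.C])
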
    def to(self, tensors, dst, src=None, axis=0,
           stream_type=_expr.StreamExpr.FIFO, depth=1, name=None):
        """Stream a list of Tensors to dst devices

        Parameters
        ----------
        tensors : list of Tensor
            The tensors to be moved

        dst : Device or Stage
            The destination of data movement

        src : Device or Stage
            The source of data movement

        axis : axis index
            Move the axis-th loop body to the xcel scope

        depth : channel depth
            The streaming channel depth
        """
        if stream_type > 2:
            raise APIError("Invalid channel type")
        rets = []
        if not isinstance(tensors, list):
            tensors = [tensors]

        for tensor in tensors:
            try:
                if isinstance(tensor, Stage):
                    target = tensor._op
                # unpack tuple of src stage and tensor
                elif isinstance(tensor, tuple):
                    src, target = tensor
                    # from hcl stage to tvm stage
                    src = self.__getitem__(src)
                else:  # target tensor
                    target = tensor.tensor
            except (AttributeError, ValueError):
                target = tensor

            # convert hcl stage
            try:
                dst = self[dst]
            except:
                pass

            if src is None:  # move to device
                if isinstance(dst, Device) or \
                   isinstance(dst, DevMediaPair):
                    if axis == 0:
                        self.placement[target] = dst
                    else:
                        assert isinstance(tensor, Stage)
                        target = self[tensor]
                else:  # inter-stage
                    src = self[tensor]

            # target can be stage or tensor
            ret = self.sch.to(target, dst, src, axis,
                              stream_type, depth)
            rets.append(ret)

        if len(rets) == 1:
            return rets[0]
        return rets

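    # Usage sketch, based on the HeteroCL data-placement tutorials. The
    # platform object ``p`` and the stage names are assumptions, and the
    # exact platform spelling varies across HeteroCL versions.
    #
    #   p = hcl.Platform.aws_f1          # or hcl.platform.aws_f1
    #   s.to([A], p.xcel)                # host -> device
    #   s.to(kernel.C, p.host)           # device -> host
    #   s.to(kernel.B, s[kernel.C])      # stream between two stages
    #   f = hcl.build(s, target=p)
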
    def partition(self, target, partition_type=_stmt.Partition.Complete,
                  dim=0, factor=0):
        """Partition a Tensor into smaller Tensors or even registers

        Users can specify the partition type, which includes Complete, Block,
        and Cyclic. The default type is Complete, which means we completely
        partition the specified dimension. If Block is specified, the tensor
        is partitioned into N blocks with equal size. The number N is
        specified by the factor. Otherwise, if Cyclic is specified, the
        elements of the tensor are partitioned in a cyclic manner. For
        example, if the factor is three, the 1st element will be assigned to
        the 1st partitioned tensor; the 2nd element will be assigned to the
        2nd one; and so on. Finally, if Complete is specified, the factor
        will be ignored. If `dim` is set to 0, it means we partition all
        dimensions.

        Parameters
        ----------
        target : Tensor
            The tensor to be partitioned

        partition_type : {Complete, Block, Cyclic}, optional
            The partition type

        dim : int, optional
            The dimension to be partitioned

        factor : int, optional
            The partition factor
        """
        if partition_type > 2:
            raise APIError("Invalid partition type")
        if dim < 0:
            raise APIError("Invalid dimension")
        if factor < 0:
            raise APIError("Invalid factor")
        try:
            target = target.tensor
        except (AttributeError, ValueError):
            try:
                target = target._op
            except AttributeError:
                pass
        return self.sch.partition(target, partition_type, dim, factor)

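    # Usage sketch: block-partition the second dimension of A into two
    # banks. ``hcl.Partition`` mirrors the ``_stmt.Partition`` enum used
    # above; tensor and schedule names are illustrative.
    #
    #   s.partition(A, hcl.Partition.Block, dim=2, factor=2)
    #   s.partition(A)                   # complete partition, all dimensions
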
    def reshape(self, target, shape):
        """Reshape a Tensor to a specified new shape

        Parameters
        ----------
        target : Tensor
            The tensor to be reshaped

        shape : tuple of int
            The new shape of the tensor
        """
        try:
            target = target.tensor
        except (AttributeError, ValueError):
            try:
                target = target._op
            except AttributeError:
                pass
        _api_internal._ScheduleReshape(self.sch, target, shape)

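    # Usage sketch (illustrative names): view a flat tensor as a 2-D array
    # with the same number of elements.
    #
    #   A = hcl.placeholder((100,), "A")
    #   def kernel(A):
    #       return hcl.compute(A.shape, lambda x: A[x] * 2, "B")
    #   s = hcl.create_schedule([A], kernel)
    #   s.reshape(kernel.B, (10, 10))    # view the 100-element B as 10x10
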
class Stage(object):
    """Create a stage in the algorithm.

    Stage is needed when an imperative DSL block is not used within any other
    compute APIs. We can further use the created stage to help us schedule
    the imperative components within it. It can also be used to describe a
    higher level of computation hierarchy. For example, we can wrap several
    compute APIs into a single stage.

    Parameters
    ----------
    name : str, optional
        The name of the Stage

    Attributes
    ----------
    stmt_stack : list of list of Stmt
        Store all statements. There are two levels. The outer level is for
        different scopes of statement. The inner level is for different
        statements

    var_dict : dict(str, _Var)
        A dictionary whose key is the name of the variable and the value is
        the variable itself. This enables users to access a variable inside
        a Stage via a Python attribute

    axis_list : list of IterVar
        A list of axes that appear in this Stage

    has_break : bool
        Set to `True` if there is a `break` statement within the stage

    has_return : bool
        Set to `True` if there is a `return` statement within the stage

    ret_dtype : Type
        The returned data type. Only exists for `heterocl.compute`

    for_level : int
        The level of a loop nest where the current statement is

    for_ID : int
        An index used to label the unnamed axes

    input_stages : set of Stage
        A set of stages that are the input to the Stage

    lhs_tensors : set of Tensor
        The tensors that are updated at the left-hand side

    last_substages : set of Stage
        A set of sub-stages that are last used in the current stage

    name_with_prefix : str
        The full name of the stage. This is used when two stages at different
        levels share the same name

    Examples
    --------
    .. code-block:: python

        A = hcl.placeholder((10,))
        with hcl.Stage():
            A[0] = 5
            with hcl.for_(1, 10) as i:
                A[i] = A[i-1] * 2
    """
    _current = []
    """Store all living `Stage`. The newest is at the end."""

    def __init__(self, name=None, dtype=None, shape=()):
        # Attributes related to a single stage
        self.name = util.get_name("stage", name)
        self.stmt_stack = [[]]
        self.var_dict = {}
        self.axis_list = []
        self.has_break = False
        self.has_return = False
        self.ret_dtype = None
        self.for_level = 0
        self.for_ID = 0
        self.substages = []
        # Attributes for cross-stage relation
        self.input_stages = set([])
        self.lhs_tensors = set([])
        self.last_substages = set([])
        self.name_with_prefix = self.name if Stage.get_len() == 0 \
            else Stage.get_current().name_with_prefix + "." + self.name
        # Private attributes for building a stage
        self._op = None
        self._hcl_dtype = util.get_dtype(dtype, self.name_with_prefix)
        self._dtype = util.get_tvm_dtype(dtype, self.name_with_prefix)
        self._buf = tvm_api.decl_buffer(shape, self._dtype, self.name)
        self._shape = self._buf.shape

    def __enter__(self):
        Stage._current.append(self)
        return self

    def __exit__(self, ptype, value, trace):
        # update input_stages: the union of the last substages and
        # original input stages collected in the stage
        self.input_stages = self.last_substages.union(self.input_stages)
        # create the output operation
        input_ops = [i._op for i in self.input_stages]
        input_bufs = [i._buf for i in self.input_stages]
        output_bufs = [self._buf]
        body = self.pop_stmt()
        Stage._current.pop()
        op = _ExternOp(self.name, "", self.axis_list, input_ops,
                       input_bufs, output_bufs, body)
        self._op = op.output(0)
        # update last_update stages
        # if this stage is a substage of other stages
        if Stage._current:
            superstage = Stage._current[-1]
            # add attribute statement for later stage insertion
            superstage.emit(
                lambda x: _make.AttrStmt(self._buf, "attach_scope",
                                         _make.StringImm(superstage.name), x))
            # update the input stages of the superstage:
            # input_stages = original input stages + current input stages
            #                - last substages
            superstage.input_stages = superstage.input_stages.union(
                self.input_stages)
            superstage.input_stages.difference_update(
                superstage.last_substages)
            # update the last substages of the superstage:
            # last_substages = original substages + current stage
            #                  - inputs of current stage
            superstage.last_substages.add(self)
            superstage.last_substages.difference_update(self.input_stages)
            # update lhs_tensors:
            # lhs_tensors = original tensors + lhs tensors of current stage
            superstage.lhs_tensors.update(self.lhs_tensors)
            # update var_dict
            superstage.var_dict[self.name] = self
            # update prefix
            self.name_with_prefix = superstage.name_with_prefix + "." \
                + self.name
            # update superstage's substages
            superstage.substages.append(self)
        # Otherwise update the list of stages globally
        else:
            Schedule.stage_ops.append(self)
            Schedule.last_stages.add(self)
            Schedule.last_stages -= self.input_stages

    def __repr__(self):
        return self.name

    def __getattr__(self, name):
        try:
            if name in self.var_dict:
                return self.var_dict[name]
            else:
                # return stage and target tensor op
                for tensor in self.lhs_tensors:
                    if tensor.name == name:
                        return (self, tensor._tensor)
                # check tensors in input stages
                for stage in self.input_stages:
                    if stage.name == name:
                        return (self, stage._op)
                # check tensors in input_stage.lhs
                for stage in self.input_stages:
                    lhs = stage.lhs_tensors
                    for tensor in lhs:
                        if tensor.name == name:
                            return (self, tensor._tensor)
                raise ValueError(
                    "Member " + name + " not found in " +
                    str(self.lhs_tensors) + " or " +
                    str(self.input_stages))
        except KeyError:
            raise ValueError("Unknown member " + name + " of " + self.name)

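    # Usage sketch: stages nest via ``with`` blocks, and ``__exit__`` above
    # performs the cross-stage bookkeeping (input_stages, last_substages,
    # name_with_prefix). The names below are illustrative.
    #
    #   A = hcl.placeholder((10,), "A")
    #   def kernel(A):
    #       with hcl.Stage("Outer"):
    #           with hcl.Stage("Inner"):  # name_with_prefix ends in "Outer.Inner"
    #               A[0] = 5
    #   s = hcl.create_schedule([A], kernel)
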
    def emit(self, stmt):
        """Insert a statement into the current stage."""
        if self.has_break:
            raise DSLError("Cannot write statements after break")
        self.stmt_stack[-1].append(stmt)

    def replace_else(self, if_stmt, else_stmt):
        """Add an ELSE or ELIF branch to an existing IF or ELIF branch."""
        assert isinstance(if_stmt, _stmt.IfThenElse), "Wrong if statement"
        if isinstance(if_stmt.else_case, _stmt.IfThenElse):
            return _make.IfThenElse(if_stmt.condition, if_stmt.then_case,
                                    self.replace_else(if_stmt.else_case,
                                                      else_stmt))
        return _make.IfThenElse(if_stmt.condition, if_stmt.then_case,
                                else_stmt)

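    # Illustrative sketch of what ``replace_else`` builds: it appends a new
    # else-branch to the innermost IfThenElse of an existing chain, which is
    # how ``hcl.elif_``/``hcl.else_`` extend a preceding ``hcl.if_``.
    #
    #   # given: if (c0) {s0} else { if (c1) {s1} }
    #   # replace_else(if_stmt, s2) yields:
    #   #        if (c0) {s0} else { if (c1) {s1} else {s2} }
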
    def pop_stmt(self):
        """Create a statement from the statements within the current stage."""
        stmts = self.stmt_stack.pop()
        if not stmts or callable(stmts[-1]):
            stmts.append(_make.Evaluate(0))
        stmt = stmts[-1]
        for s in reversed(stmts[:-1]):
            if callable(s):
                stmt = s(stmt)
            else:
                assert isinstance(s, _stmt.Stmt)
                stmt = _make.Block(s, stmt)
        return stmt

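    # Illustrative sketch of the statement-stack convention used by
    # ``emit``/``pop_stmt``: plain Stmts are sequenced with ``Block``, while
    # callables are deferred wrappers (e.g. the AttrStmt lambda emitted in
    # ``__exit__``) that receive everything emitted after them as their body.
    #
    #   stage.emit(stmt_a)                         # a ready-made Stmt
    #   stage.emit(lambda body: _make.AttrStmt(buf, "attach_scope",
    #                                          _make.StringImm("S"), body))
    #   stage.emit(stmt_b)
    #   stage.pop_stmt()  # -> Block(stmt_a, AttrStmt(..., stmt_b))
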
    @staticmethod
    def get_current():
        """Get the current stage."""
        return Stage._current[-1]

    @staticmethod
    def get_len():
        """Get the level of stages."""
        return len(Stage._current)

    @property
    def axis(self):
        """Get the axes of the stage."""
        return self._op.op.axis