# Source code for heterocl.util

"""Utility functions for HeteroCL"""
from .tvm import make as _make
from .tvm import expr as _expr
from .tvm.expr import Var, Call
from .tvm.api import _IterVar, decl_buffer
from . import types
from . import devices
from . import config
from .scheme import Scheme
from .debug import DTypeError
from .mutator import Mutator

class VarName():
    """A per-type counter used to auto-generate variable names.

    Attributes
    ----------
    name_dict: dict
        Maps a variable-type string to the most recently issued numeric
        suffix for that type. After ``n`` generated names of one type the
        stored value is ``n - 1`` (suffixes start at 0). ``get_name`` reads
        and writes this dictionary through ``VarName.name_dict``.
    """
    # Stored on the class rather than per instance, so every caller in the
    # process shares the same counters.
    name_dict = {}
def get_name(var_type, name=None):
    """Return a variable name, generating one from ``var_type`` if needed.

    An explicit ``name`` is returned untouched and does not advance any
    counter. Otherwise the per-type counter in ``VarName.name_dict`` is
    advanced and its new value is appended to ``var_type``. Successive calls
    therefore yield ``var_type + "0"``, ``var_type + "1"``, and so on.

    Parameters
    ----------
    var_type: str
        The kind of variable. It is used both as the counter key and as the
        prefix of generated names.

    name: str, optional
        A user-specified name that takes precedence over generation.

    Returns
    -------
    new_name: str
        The chosen or generated variable name.
    """
    if name is not None:
        return name
    last_index = VarName.name_dict.get(var_type)
    # No entry yet means this is the first generated name for the type.
    next_index = 0 if last_index is None else last_index + 1
    VarName.name_dict[var_type] = next_index
    return var_type + str(next_index)
def get_dtype(dtype, name=None):
    """Resolve the data type of a variable.

    The resolution order is:

    1. A type registered for ``name`` in the active scheme's ``dtype_dict``
       (when ``Scheme.current`` is set).
    2. The ``dtype`` argument.
    3. ``config.init_dtype``.

    Parameters
    ----------
    dtype: Type or str or None
        The data type requested by the caller, if any.

    name: str, optional
        The name of the variable, used to look up a scheme override.

    Returns
    -------
    dtype: Type or str
        The resolved data type. It is returned as-is; no conversion to str
        is performed here.
    """
    scheme = Scheme.current
    if scheme is not None:
        override = scheme.dtype_dict.get(name)
        if override is not None:
            dtype = override
    if dtype is None:
        dtype = config.init_dtype
    return dtype
def get_tvm_dtype(dtype, name=None):
    """Resolve a data type and convert it to its string form.

    The type is resolved with ``get_dtype`` (scheme override, then
    ``dtype``, then the configured default). The result is converted with
    ``types.dtype_to_str``.

    Parameters
    ----------
    dtype: Type or str or None
        The data type requested by the caller, if any.

    name: str, optional
        The name of the variable, used to look up a scheme override.

    Returns
    -------
    str
        The value returned by ``types.dtype_to_str``.
    """
    resolved = get_dtype(dtype, name)
    return types.dtype_to_str(resolved)
def true():
    """Return the constant ``_make.UIntImm("uint1", 1)``.

    This is a one-bit unsigned immediate holding the value 1.
    """
    one_bit_one = _make.UIntImm("uint1", 1)
    return one_bit_one
def make_for(indices, body, level):
    """Build a loop nest over ``indices[level:]`` around ``body``.

    ``indices[level]`` becomes the outermost loop and the last element the
    innermost. Each loop's body is wrapped in an ``AttrStmt`` with key
    ``"loop_scope"`` on that loop's iteration variable. The loop bounds
    come from ``iter_var.dom.min`` and ``iter_var.dom.extent``.

    Parameters
    ----------
    indices: sequence
        Iteration variables exposing ``.var`` and ``.dom``.

    body: Stmt
        The statement placed inside the innermost loop.

    level: int
        Index of the iteration variable for the outermost loop to build.

    Returns
    -------
    Stmt
        The ``_make.For`` statement for ``indices[level]``.
    """
    iter_var = indices[level]
    is_innermost = level == len(indices) - 1
    # Build the nested loops first (recursively) unless this is the last level.
    inner = body if is_innermost else make_for(indices, body, level + 1)
    scoped = _make.AttrStmt(iter_var, "loop_scope", iter_var.var, inner)
    # The two literal zeros are passed through unchanged. Presumably they are
    # TVM's for_type and device_api arguments to For -- confirm against the
    # bundled TVM.
    return _make.For(iter_var.var, iter_var.dom.min, iter_var.dom.extent,
                     0, 0, scoped)
# Returns a tuple (flat_index, bit, extent); see the docstring below.
def get_index(shape, args, level):
    """Build a row-major flattened index for ``args[level:]``.

    The arguments are combined innermost-first. Each one is multiplied by
    the extent accumulated from the dimensions after it, and the product is
    added to the running index.

    Parameters
    ----------
    shape: sequence
        Extent of each dimension.

    args: sequence
        Index arguments. When the last argument sits at position
        ``len(shape)`` (one more argument than dimensions), it is treated as
        a bit selector rather than a dimension index.

    level: int
        Position in ``args``/``shape`` to start from.

    Returns
    -------
    tuple
        ``(index, bit, extent)``, where:

        - ``index`` is the flattened index expression. It is the Python int
          ``0`` when only a bit selector remains.
        - ``bit`` is the bit-selection argument, or ``None``.
        - ``extent`` is the accumulated product of the consumed dimensions'
          extents (``1`` for the bit slot).
    """
    if level != len(args) - 1:
        inner_index, bit, inner_extent = get_index(shape, args, level + 1)
        # The third ``False`` argument to _make.Mul/_make.Add is passed
        # through as in the rest of this module; its meaning is not visible
        # here.
        offset = _make.Mul(args[level], inner_extent, False)
        flat_index = _make.Add(inner_index, offset, False)
        extent = _make.Mul(inner_extent, shape[level], False)
        return (flat_index, bit, extent)
    if level == len(shape):
        # One more argument than dimensions: the trailing one selects a bit.
        return (0, args[level], 1)
    return (args[level], None, shape[level])
def get_type(dtype):
    """Split a data-type string into its kind and width fields.

    Parameters
    ----------
    dtype: str
        A type string such as ``"int32"``, ``"uint8"``, ``"float32"``,
        ``"fixed16_8"`` or ``"ufixed16_8"``.

    Returns
    -------
    tuple
        ``(kind, bits)`` for ``int``/``uint``/``float`` types, or
        ``(kind, bits, fracs)`` for ``fixed``/``ufixed`` types. ``kind`` is
        the prefix string and the widths are ints.

    Raises
    ------
    ValueError
        If the prefix is unknown or the width fields are missing or
        malformed (e.g. ``"int"``, ``"fixed8"``, ``"fixed16_8_2"``).
    """
    try:
        if dtype[:3] == "int":
            return "int", int(dtype[3:])
        if dtype[:4] == "uint":
            return "uint", int(dtype[4:])
        if dtype[:5] == "float":
            return "float", int(dtype[5:])
        if dtype[:5] == "fixed":
            # Exactly two '_'-separated fields are required; unpacking
            # raises ValueError otherwise.
            bits, fracs = dtype[5:].split('_')
            return "fixed", int(bits), int(fracs)
        if dtype[:6] == "ufixed":
            bits, fracs = dtype[6:].split('_')
            return "ufixed", int(bits), int(fracs)
    except ValueError:
        # Report malformed width fields (bad int() literals or a wrong
        # number of '_'-separated parts) with the same message as an
        # unknown prefix. Previously, "fixed8" escaped as a bare IndexError.
        raise ValueError("Unknown data type: " + dtype)
    raise ValueError("Unknown data type: " + dtype)
class CastRemover(Mutator):
    """Mutator that drops ``Cast`` nodes and unwraps constant expressions.

    The dispatch from ``self.mutate`` to the ``mutate_*`` hooks is provided
    by the ``Mutator`` base class, whose details are not visible here.
    """

    def mutate_ConstExpr(self, node):
        """Replace a constant expression by its raw ``node.value``."""
        return node.value

    def mutate_BinOp(self, binop, node):
        """Rebuild a binary operation from its mutated operands.

        Both operands are mutated first (``node.a``, then ``node.b``). Any
        operand that is still a ``_expr.ConstExpr`` is reduced to its
        ``.value``. The result is ``binop(lhs, rhs, False)``; the meaning of
        the trailing ``False`` depends on ``binop``, which is supplied by
        the caller -- NOTE(review): confirm.
        """
        def unwrap(operand):
            if isinstance(operand, _expr.ConstExpr):
                return operand.value
            return operand

        mutated_a = self.mutate(node.a)
        mutated_b = self.mutate(node.b)
        return binop(unwrap(mutated_a), unwrap(mutated_b), False)

    def mutate_Cast(self, node):
        """Discard the cast and continue mutating the wrapped expression."""
        return self.mutate(node.value)