# Source code for heterocl.tvm.target

from __future__ import absolute_import

import warnings
from ._ffi.base import _LIB_NAME

try:
    from decorator import decorate
except ImportError as err_msg:
    # Allow decorator to be missing in runtime
    if _LIB_NAME != "libhcl_runtime.so":
        raise err_msg

# Target names that Target.__init__ tags with the "fpga" dispatch key.
# NOTE(review): 'opencl' is also matched by the earlier ("rocm", "opencl")
# branch in Target.__init__, so it receives the "rocm"/"gpu" keys and never
# reaches the FPGA branch -- confirm whether that is intended.
FPGA_TARGETS = ['merlinc', 'soda', 'soda_xhls', 'vhls', 'ihls', 'vhls_csim', 
                'opencl', 'xocl', 'aocl', 'rv64_ppac']

def _merge_opts(opts, new_opts):
    """Helper function to merge options"""
    if isinstance(new_opts, str):
        new_opts = new_opts.split()
    if new_opts:
        opt_set = set(opts)
        new_opts = [opt for opt in new_opts if opt not in opt_set]
        return opts + new_opts
    return opts


class Target(object):
    """Target device information, use through TVM API.

    Parameters
    ----------
    target_name : str
        The major target name, one of {"llvm", "cuda", "nvptx", "rocm",
        "opencl", "metal", "vulkan", "opengl", "stackvm", "ext_dev"}, or a
        HeteroCL specific FPGA target name listed in ``FPGA_TARGETS``
        (for example "merlinc", "soda", "soda_xhls", "vhls").

    options : str or list of str, optional
        Additional arguments appended to the target.

    Attributes
    ----------
    target_name : str
        The name given to the constructor.
    options : list of str
        The de-duplicated option list.
    device_name : str
        Value of the last ``-device=`` option, or "" when absent.
    libs : list of str
        Comma separated values collected from every ``-libs=`` option.
    keys : tuple of str
        Dispatch keys, device name first (when set), then the target keys.
    thread_warp_size : int
        32 for cuda/nvptx, 1 otherwise.
    max_num_threads : int
        Only set for cuda/nvptx (512) and rocm/opencl/metal/vulkan (256).

    Raises
    ------
    ValueError
        If ``target_name`` is not a known target.

    Note
    ----
    Do not use class constructor, you can create target using the following
    functions

    - :any:`tvm.target.create` create target from string
    - :any:`tvm.target.rasp` create raspberry pi target
    - :any:`tvm.target.cuda` create CUDA target
    - :any:`tvm.target.rocm` create ROCM target
    - :any:`tvm.target.mali` create Mali target
    """
    # The target of the innermost active ``with target:`` scope, or None.
    current = None

    def __init__(self, target_name, options=None):
        self.target_name = target_name
        self.options = _merge_opts([], options)
        self.device_name = ""
        self.libs = []
        # Parse device option
        for item in self.options:
            # Split on the first "=" only, so a value may itself contain "=".
            if item.startswith("-libs="):
                self.libs += item.split("=", 1)[1].split(",")
            elif item.startswith("-device="):
                self.device_name = item.split("=", 1)[1]
        # Target query searches device name first
        if self.device_name:
            self.keys = (self.device_name,)
        else:
            self.keys = ()
        # Target configuration handling
        self.thread_warp_size = 1
        if target_name in ("llvm", ):
            self.keys += ("cpu",)
        elif target_name in ("cuda", "nvptx"):
            self.keys += ("cuda", "gpu")
            self.max_num_threads = 512
            self.thread_warp_size = 32
        elif target_name in ("rocm", "opencl"):
            # For now assume rocm schedule for opencl
            self.keys += ("rocm", "gpu")
            self.max_num_threads = 256
        elif target_name in ("metal", "vulkan"):
            self.keys += (target_name, "gpu",)
            self.max_num_threads = 256
        elif target_name in ("opengl",):
            self.keys += ("opengl",)
        elif target_name in ("stackvm", "ext_dev"):
            # No dispatch keys are known for stackvm or ext_dev
            pass
        elif target_name in FPGA_TARGETS:
            self.keys += ("fpga",)
        else:
            raise ValueError("Unknown target name %s" % target_name)

    def __str__(self):
        return " ".join([self.target_name] + self.options)

    def __repr__(self):
        return self.__str__()

    def __enter__(self):
        """Make this target the current one and remember the previous one."""
        previous = Target.current
        if previous is not None and str(self) != str(previous):
            warnings.warn(
                "Override target '%s' with new target scope '%s'" % (
                    previous, self))
        # Keep a stack rather than a single slot: re-entering the same Target
        # object would otherwise overwrite the saved value and the outermost
        # exit would leave Target.current pointing at this object.
        saved = getattr(self, "_old_targets", None)
        if saved is None:
            saved = self._old_targets = []
        saved.append(previous)
        # Still maintained for code that reads the most recently saved target.
        self._old_target = previous
        Target.current = self
        return self

    def __exit__(self, ptype, value, trace):
        """Restore the target that was current when this scope was entered."""
        Target.current = self._old_targets.pop()
def generic_func(fdefault):
    """Wrap a target generic function.

    Generic function allows registration of further functions
    that can be dispatched on current target context.
    If no registered dispatch is matched, the fdefault will be called.

    Parameters
    ----------
    fdefault : function
        The default function.

    Returns
    -------
    fgeneric : function
        A wrapped generic function. It carries a ``register`` attribute used
        to add target specific implementations.

    Example
    -------
    .. code-block:: python

      import tvm
      # wrap function as target generic
      @tvm.target.generic_func
      def my_func(a):
          return a + 1
      # register specialization of my_func under target cuda
      @my_func.register("cuda")
      def my_func_cuda(a):
          return a + 2
      # displays 3, because my_func is called
      print(my_func(2))
      # displays 4, because my_func_cuda is called
      with tvm.target.cuda():
          print(my_func(2))
    """
    dispatch_dict = {}
    func_name = fdefault.__name__

    def register(key, func=None, override=False):
        """Register function to be the dispatch function.

        Parameters
        ----------
        key : str or list of str
            The key to be registered.

        func : function
            The function to be registered.

        override : bool
            Whether override existing registration.

        Returns
        -------
        The registered function when ``func`` is given, otherwise a
        decorator that registers the function it is applied to.

        Raises
        ------
        ValueError
            If a key is already registered and ``override`` is False.
        """
        def _do_reg(myf):
            key_list = [key] if isinstance(key, str) else list(key)
            if not override:
                # Check every key before touching dispatch_dict so that a
                # conflict does not leave some of the keys registered.
                for k in key_list:
                    if k in dispatch_dict:
                        raise ValueError(
                            "Key is already registered for %s" % func_name)
            for k in key_list:
                dispatch_dict[k] = myf
            return myf
        if func is not None:
            return _do_reg(func)
        return _do_reg

    def dispatch_func(func, *args, **kwargs):
        """The wrapped dispatch function"""
        target = current_target()
        if target is None:
            return func(*args, **kwargs)
        # target.keys is ordered, the first key with a registration wins.
        for k in target.keys:
            if k in dispatch_dict:
                return dispatch_dict[k](*args, **kwargs)
        return func(*args, **kwargs)

    fdecorate = decorate(fdefault, dispatch_func)
    fdecorate.register = register
    return fdecorate
def cuda(options=None):
    """Returns a cuda target.

    Parameters
    ----------
    options : list of str
        Additional options

    Returns
    -------
    target : Target
        A ``Target`` built with the name "cuda" and the given options.
    """
    return Target("cuda", options)
def rocm(options=None):
    """Returns a ROCM target.

    Parameters
    ----------
    options : list of str
        Additional options

    Returns
    -------
    target : Target
        A ``Target`` built with the name "rocm" and the given options.
    """
    return Target("rocm", options)
def rasp(options=None):
    """Returns a rasp target.

    The result is an "llvm" target whose options start with the fixed
    ``-device=rasp`` / ARM settings below, followed by any entry of
    ``options`` that is not already among them.

    Parameters
    ----------
    options : list of str
        Additional options

    Returns
    -------
    target : Target
        The "llvm" target configured for the rasp device.
    """
    opts = ["-device=rasp",
            "-mtriple=armv7l-none-linux-gnueabihf",
            "-mcpu=cortex-a53",
            "-mattr=+neon"]
    opts = _merge_opts(opts, options)
    return Target("llvm", opts)
def mali(options=None):
    """Returns an ARM Mali GPU target.

    The result is an "opencl" target whose options start with
    ``-device=mali``, followed by any entry of ``options`` that is not
    already present.

    Parameters
    ----------
    options : list of str
        Additional options

    Returns
    -------
    target : Target
        The "opencl" target configured for the mali device.
    """
    opts = ["-device=mali"]
    opts = _merge_opts(opts, options)
    return Target("opencl", opts)
def opengl(options=None):
    """Returns a OpenGL target.

    Parameters
    ----------
    options : list of str
        Additional options

    Returns
    -------
    target : Target
        A ``Target`` built with the name "opengl" and the given options.
    """
    return Target("opengl", options)
def create(target_str):
    """Get a target given target string.

    Parameters
    ----------
    target_str : str or Target
        The target string: a target name followed by whitespace separated
        options. A ``Target`` instance is returned unchanged.

    Returns
    -------
    target : Target
        The target object

    Raises
    ------
    ValueError
        If ``target_str`` is neither a string nor a ``Target``, or if it
        contains no target name.

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    """
    if isinstance(target_str, Target):
        return target_str
    if not isinstance(target_str, str):
        raise ValueError("target_str has to be string type")
    arr = target_str.split()
    if not arr:
        # An empty or whitespace-only string used to fail with a bare
        # IndexError on arr[0]; report the actual problem instead.
        raise ValueError("target_str has to contain a target name")
    options = arr[1:]
    # Parse device option, the last -device= entry wins
    device_name = ""
    for item in options:
        if item.startswith("-device="):
            device_name = item.split("=", 1)[1]
    # Known devices go through their factory, which supplies the target name
    # and default options itself (the name in target_str is not used).
    if device_name == "rasp":
        return rasp(options)
    if device_name == "mali":
        return mali(options)
    return Target(arr[0], options)
def current_target(allow_none=True):
    """Returns the current target.

    Parameters
    ----------
    allow_none : bool
        Whether allow the current target to be none

    Returns
    -------
    target : Target or None
        The value of ``Target.current``. It is None only when no target
        scope is active and ``allow_none`` is True.

    Raises
    ------
    RuntimeError
        If current target is not set and ``allow_none`` is False.
    """
    target = Target.current
    if target is None and not allow_none:
        raise RuntimeError(
            "Requires a current target in generic function, but it is not set. "
            "Please set it using `with TargetObject:`")
    return target