from .tvm import expr as _expr
from .tvm import stmt as _stmt
from .tvm import make as _make
from .tvm.api import convert
[docs]class Mutator(object):
[docs] def mutate(self, node):
if isinstance(node, _expr.Expr):
if isinstance(node, _expr.ConstExpr):
return self.mutate_ConstExpr(node)
elif isinstance(node, _expr.BinaryOpExpr):
if isinstance(node, _expr.Add):
return self.mutate_Add(node)
elif isinstance(node, _expr.Sub):
return self.mutate_Sub(node)
elif isinstance(node, _expr.Mul):
return self.mutate_Mul(node)
elif isinstance(node, _expr.Div):
return self.mutate_Div(node)
elif isinstance(node, _expr.Mod):
return self.mutate_Mod(node)
elif isinstance(node, _expr.Min):
return self.mutate_Min(node)
elif isinstance(node, _expr.Max):
return self.mutate_Max(node)
else:
return node
elif isinstance(node, _expr.CmpExpr):
if isinstance(node, _expr.EQ):
return self.mutate_EQ(node)
elif isinstance(node, _expr.NE):
return self.mutate_NE(node)
elif isinstance(node, _expr.LT):
return self.mutate_LT(node)
elif isinstance(node, _expr.LE):
return self.mutate_LE(node)
elif isinstance(node, _expr.GT):
return self.mutate_GT(node)
elif isinstance(node, _expr.GE):
return self.mutate_GE(node)
else:
return node
elif isinstance(node, _expr.LogicalExpr):
if isinstance(node, _expr.And):
return self.mutate_And(node)
elif isinstance(node, _expr.Or):
return self.mutate_Or(node)
elif isinstance(node, _expr.Not):
return self.mutate_Not(node)
else:
return node
else:
if isinstance(node, _expr.Var):
return self.mutate_Var(node)
elif isinstance(node, _expr.Cast):
return self.mutate_Cast(node)
elif isinstance(node, _expr.Select):
return self.mutate_Select(node)
elif isinstance(node, _expr.Load):
return self.mutate_Load(node)
elif isinstance(node, _expr.Ramp):
return self.mutate_Ramp(node)
elif isinstance(node, _expr.Broadcast):
return self.mutate_Broadcast(node)
elif isinstance(node, _expr.Call):
return self.mutate_Call(node)
elif isinstance(node, _expr.Let):
return self.mutate_Let(node)
elif isinstance(node, _expr.GetBit):
return self.mutate_GetBit(node)
elif isinstance(node, _expr.GetSlice):
return self.mutate_GetSlice(node)
elif isinstance(node, _expr.SetBit):
return self.mutate_SetBit(node)
elif isinstance(node, _expr.SetSlice):
return self.mutate_SetSlice(node)
elif isinstance(node, _expr.KernelExpr):
return self.mutate_KernelExpr(node)
elif isinstance(node, _expr.StreamExpr):
return self.mutate_StreamExpr(node)
else:
return node
elif isinstance(node, _stmt.Stmt):
if isinstance(node, _stmt.LetStmt):
return self.mutate_LetStmt(node)
elif isinstance(node, _stmt.AssertStmt):
return self.mutate_AssertStmt(node)
elif isinstance(node, _stmt.ProducerConsumer):
return self.mutate_ProducerConsumer(node)
elif isinstance(node, _stmt.ExternModule):
return self.mutate_ExternModule(node)
elif isinstance(node, _stmt.For):
return self.mutate_For(node)
elif isinstance(node, _stmt.Store):
return self.mutate_Store(node)
elif isinstance(node, _stmt.Allocate):
return self.mutate_Allocate(node)
elif isinstance(node, _stmt.AttrStmt):
return self.mutate_AttrStmt(node)
elif isinstance(node, _stmt.Free):
return self.mutate_Free(node)
elif isinstance(node, _stmt.Block):
return self.mutate_Block(node)
elif isinstance(node, _stmt.IfThenElse):
return self.mutate_IfThenElse(node)
elif isinstance(node, _stmt.Evaluate):
return self.mutate_Evaluate(node)
elif isinstance(node, _stmt.KernelDef):
return self.mutate_KernelDef(node)
elif isinstance(node, _stmt.KernelStmt):
return self.mutate_KernelStmt(node)
elif isinstance(node, _stmt.Return):
return self.mutate_Return(node)
elif isinstance(node, _stmt.Break):
return self.mutate_Break(node)
elif isinstance(node, _stmt.While):
return self.mutate_While(node)
elif isinstance(node, _stmt.StreamStmt):
return self.mutate_StreamStmt(node)
else:
return node
elif isinstance(node, tuple):
return self.mutate_Tuple(node)
elif isinstance(node, list):
return self.mutate_List(node)
elif callable(node):
return self.mutate_Function(node)
else:
return node
[docs] def mutate_ConstExpr(self, node):
return node
[docs] def mutate_BinOp(self, binop, node):
a = self.mutate(node.a)
b = self.mutate(node.b)
return binop(a, b)
[docs] def mutate_Add(self, node):
return self.mutate_BinOp(_make.Add, node)
[docs] def mutate_Sub(self, node):
return self.mutate_BinOp(_make.Sub, node)
[docs] def mutate_Mul(self, node):
return self.mutate_BinOp(_make.Mul, node)
[docs] def mutate_Div(self, node):
return self.mutate_BinOp(_make.Div, node)
[docs] def mutate_Mod(self, node):
return self.mutate_BinOp(_make.Mod, node)
[docs] def mutate_Min(self, node):
return self.mutate_BinOp(_make.Min, node)
[docs] def mutate_Max(self, node):
return self.mutate_BinOp(_make.Max, node)
[docs] def mutate_EQ(self, node):
return self.mutate_BinOp(_make.EQ, node)
[docs] def mutate_NE(self, node):
return self.mutate_BinOp(_make.NE, node)
[docs] def mutate_LT(self, node):
return self.mutate_BinOp(_make.LT, node)
[docs] def mutate_LE(self, node):
return self.mutate_BinOp(_make.LE, node)
[docs] def mutate_GT(self, node):
return self.mutate_BinOp(_make.GT, node)
[docs] def mutate_GE(self, node):
return self.mutate_BinOp(_make.GE, node)
[docs] def mutate_And(self, node):
return self.mutate_BinOp(_make.And, node)
[docs] def mutate_Or(self, node):
return self.mutate_BinOp(_make.Or, node)
[docs] def mutate_Not(self, node):
a = self.mutate(node.a)
return _make.Not(a)
[docs] def mutate_Var(self, node):
return node
[docs] def mutate_Cast(self, node):
value = self.mutate(node.value)
return _make.Cast(node.dtype, value)
[docs] def mutate_Select(self, node):
condition = _make.Cast("uint1", self.mutate(node.condition))
true_value = convert(self.mutate(node.true_value))
false_value = convert(self.mutate(node.false_value))
return _make.Select(condition, true_value, _make.Cast(true_value.dtype, false_value))
[docs] def mutate_Load(self, node):
buffer_var = self.mutate(node.buffer_var)
index = self.mutate(node.index)
predicate = self.mutate(node.predicate)
return _make.Load(node.dtype, buffer_var, index, predicate)
[docs] def mutate_Ramp(self, node):
base = self.mutate(node.base)
stride = self.mutate(node.stride)
return _make.Ramp(base, stride, node.lanes)
[docs] def mutate_Broadcast(self, node):
value = self.mutate(node.value)
return _make.Broadcast(value, node.lanes)
[docs] def mutate_Call(self, node):
args = []
for arg in node.args:
args.append(self.mutate(arg))
return _make.Call(node.dtype, node.name, args, node.call_type, node.func, node.value_index)
[docs] def mutate_Let(self, node):
var = self.mutate(node.var)
value = self.mutate(node.value)
body = self.mutate(node.body)
return _make.Let(var, value, body)
[docs] def mutate_GetBit(self, node):
a = self.mutate(node.a)
index = self.mutate(node.index)
return _make.GetBit(a, index)
[docs] def mutate_GetSlice(self, node):
a = self.mutate(node.a)
index_left = self.mutate(node.index_left)
index_right = self.mutate(node.index_right)
return _make.GetSlice(a, index_left, index_right)
[docs] def mutate_SetBit(self, node):
a = self.mutate(node.a)
value = self.mutate(node.value)
index = self.mutate(node.index)
return _make.SetBit(a, value, index)
[docs] def mutate_SetSlice(self, node):
a = self.mutate(node.a)
value = self.mutate(node.value)
index_left = self.mutate(node.index_left)
index_right = self.mutate(node.index_right)
return _make.SetSlice(a, value, index_left, index_right)
[docs] def mutate_KernelExpr(self, node):
args = self.mutate(node.args)
return _make.KernelExpr(node.dtype, args, node.name)
[docs] def mutate_StreamExpr(self, node):
args = self.mutate(node.args)
return _make.StreamExpr(node.dtype, args, node.name)
# statements
[docs] def mutate_LetStmt(self, node):
var = self.mutate(node.var)
value = self.mutate(node.value)
body = self.mutate(node.body)
return _make.LetStmt(var, value, body)
[docs] def mutate_AssertStmt(self, node):
condition = self.mutate(node.condition)
message = self.mutate(node.message)
body = self.mutate(node.body)
return _make.AssertStmt(condition, message, body)
[docs] def mutate_ProducerConsumer(self, node):
body = self.mutate(node.body)
return _make.ProducerConsumer(node.func, node.is_producer, body)
[docs] def mutate_ExternModule(self, node):
body = self.mutate(node.body)
return _make.ExternModule(node.attr_key, node.value, body,
node.annotate_keys, node.annotate_values)
[docs] def mutate_For(self, node):
loop_var = self.mutate(node.loop_var)
_min = self.mutate(node.min)
extent = self.mutate(node.extent)
body = self.mutate(node.body)
return _make.For(loop_var, _min, extent, node.for_type, node.device_api, body)
[docs] def mutate_Store(self, node):
buffer_var = self.mutate(node.buffer_var)
index = self.mutate(node.index)
value = self.mutate(node.value)
predicate = self.mutate(node.predicate)
return _make.Store(buffer_var, value, index, predicate)
[docs] def mutate_Allocate(self, node):
buffer_var = self.mutate(node.buffer_var)
extents = self.mutate(node.extents)
condition = self.mutate(node.condition)
body = self.mutate(node.body)
return _make.Allocate(buffer_var, node.dtype, extents, condition, body)
[docs] def mutate_AttrStmt(self, node):
value = self.mutate(node.value)
body = self.mutate(node.body)
return _make.AttrStmt(node.node, node.attr_key, value, body)
[docs] def mutate_Free(self, node):
buffer_var = self.mutate(node.buffer_var)
return _make.Free(buffer_var)
[docs] def mutate_Block(self, node):
first = self.mutate(node.first)
rest = self.mutate(node.rest)
return _make.Block(first, rest)
[docs] def mutate_IfThenElse(self, node):
condition = self.mutate(node.condition)
then_case = self.mutate(node.then_case)
else_case = self.mutate(node.else_case)
return _make.IfThenElse(condition, then_case, else_case)
[docs] def mutate_Evaluate(self, node):
value = self.mutate(node.value)
return _make.Evaluate(value)
[docs] def mutate_KernelDef(self, node):
args = self.mutate(node.args)
body = self.mutate(node.body)
ret_void = self.mutate(node.ret_void)
return _make.KernelDef(args, body, ret_void, node.ret_type, node.name)
[docs] def mutate_KernelStmt(self, node):
args = self.mutate(node.args)
return _make.KernelStmt(args, node.name)
[docs] def mutate_StreamStmt(self, node):
args = self.mutate(node.args)
return _make.StreamStmt(node.dtype, args, node.name)
[docs] def mutate_Return(self, node):
value = self.mutate(node.value)
return _make.Return(value)
[docs] def mutate_Break(self, node):
return _make.Break()
[docs] def mutate_While(self, node):
condition = self.mutate(node.condition)
bdoy = self.mutate(node.body)
return _make.While(condition, body)
[docs] def mutate_Tuple(self, node):
_list = list(node)
_list = self.mutate(_list)
return tuple(_list)
[docs] def mutate_List(self, node):
_len = len(node)
_list = []
for i in range(0, _len):
_list.append(self.mutate(node[i]))
return _list
[docs] def mutate_Function(self, node):
return node