Note
Click here to download the full example code
HeteroCL Tutorial: K-means Clustering Algorithm
Author: Yi-Hsiang Lai (seanlatias@github), Ziyan Feng
This tutorial implements the K-means clustering algorithm in HeteroCL.
import numpy as np
import heterocl as hcl
import time
import random
Define the number of clusters (means) as K, the number of points as N, the number of dimensions per point as dim, and the number of iterations as niter.
# Problem size: K cluster means, N points of `dim` coordinates each, and
# `niter` rounds of the assign/update loop.
K, N, dim, niter = 16, 320, 32, 200

# Initialize HeteroCL with its default configuration (no arguments given).
hcl.init()
Main Algorithm
def top(target=None):
    """Build the K-means kernel, apply the schedule, and compile it.

    The compiled function is called as ``f(points, means, labels)``:
    ``points`` has shape (N, dim), ``means`` has shape (K, dim), and
    ``labels`` (shape (N,)) is the kernel's output. The kernel runs ``niter``
    rounds of nearest-mean assignment followed by a mean update. ``means`` is
    updated in place, and ``labels`` ends up holding each point's cluster
    index.

    Args:
        target: forwarded unchanged to ``hcl.build``. With ``None`` (the
            default), ``hcl.build`` uses its own default target.

    Returns:
        The function produced by ``hcl.build``.
    """
    # Largest Int(32) value, used as the "no mean seen yet" distance sentinel.
    # NOTE(review): assumes hcl.scalar defaults to Int(32), i.e. hcl.init()
    # was called with its default init dtype. Confirm if that changes.
    int32_max = 2**31 - 1

    points = hcl.placeholder((N, dim))
    means = hcl.placeholder((K, dim))

    def kmeans(points, means):
        def loop_kernel(labels):
            # Assignment step: give each point the index of its nearest mean
            # by squared Euclidean distance. The comparison is strict, so a
            # tie keeps the lower index.
            with hcl.for_(0, N, name="N") as n:
                # The sentinel must exceed every reachable squared distance.
                # The old value, 100000, is smaller than 32 * 99**2 = 313632,
                # the maximum for 32 dimensions with coordinates in [0, 100).
                # A point that far from every mean kept a stale label.
                min_dist = hcl.scalar(int32_max)
                with hcl.for_(0, K) as k:
                    dist = hcl.scalar(0)
                    with hcl.for_(0, dim) as d:
                        dist_ = points[n, d] - means[k, d]
                        dist.v += dist_ * dist_
                    with hcl.if_(dist.v < min_dist.v):
                        min_dist.v = dist.v
                        labels[n] = k

            # Update step: per-cluster point counts and coordinate sums.
            num_k = hcl.compute((K,), lambda x: 0)
            sum_k = hcl.compute((K, dim), lambda x, y: 0)

            def calc_sum(n):
                num_k[labels[n]] += 1
                with hcl.for_(0, dim) as d:
                    sum_k[labels[n], d] += points[n, d]

            hcl.mutate((N,), lambda n: calc_sum(n), "calc_sum")
            # New mean = integer (floor) average of the cluster's points.
            # NOTE(review): a cluster with no points (num_k[k] == 0) divides
            # by zero here. This case is not handled.
            hcl.update(means,
                       lambda k, d: sum_k[k, d] // num_k[k], "update_mean")

        labels = hcl.compute((N,), lambda x: 0)
        hcl.mutate((niter,), lambda _: loop_kernel(labels), "main_loop")
        return labels

    # Create the schedule and apply compute customizations. The stage and
    # loop names used below ("main_loop", "N", "calc_sum", "update_mean")
    # must match the names given in kmeans().
    s = hcl.create_schedule([points, means], kmeans)
    main_loop = kmeans.main_loop
    update_mean = main_loop.update_mean
    # Pipeline the per-point assignment loop.
    s[main_loop].pipeline(main_loop.N)
    # Unroll the accumulation over points.
    s[main_loop.calc_sum].unroll(main_loop.calc_sum.axis[0])
    # Fuse the (K, dim) mean-update loops into one loop, then unroll it.
    fused = s[update_mean].fuse(update_mean.axis[0], update_mean.axis[1])
    s[update_mean].unroll(fused)
    return hcl.build(s, target=target)
# Build and compile the kernel with the default target.
f = top()

# Synthetic input: N random integer points with coordinates in [0, 100),
# all-zero initial labels, and K distinct data points as the initial means.
points_np = np.random.randint(100, size=(N, dim))
labels_np = np.zeros(N)
means_np = points_np[random.sample(range(N), K), :]

# Wrap the NumPy data as HeteroCL arrays; points and means use hcl.Int().
hcl_points, hcl_means = (hcl.asarray(arr, dtype=hcl.Int())
                         for arr in (points_np, means_np))
hcl_labels = hcl.asarray(labels_np)

# Time a single kernel invocation.
start = time.time()
f(hcl_points, hcl_means, hcl_labels)
total_time = time.time() - start
print("Kernel time (s): %.2f" % total_time)

for caption, result in (("All points:", hcl_points),
                        ("Final cluster:", hcl_labels),
                        ("The means:", hcl_means)):
    print(caption)
    print(result)

# Check against the reference implementation. It receives the points with
# the initial labels appended as an extra column.
# NOTE(review): the assert below assumes kmeans_golden updates means_np in
# place. Its source is not shown here, so confirm this.
from kmeans_golden import kmeans_golden
golden_input = np.concatenate((points_np, labels_np[:, np.newaxis]), axis=1)
kmeans_golden(niter, K, N, dim, golden_input, means_np)
assert np.allclose(hcl_means.asnumpy(), means_np)
Out:
Kernel time (s): 0.04
All points:
[[99 85 64 ... 54 94 83]
[43 26 32 ... 1 57 39]
[15 93 93 ... 12 41 72]
...
[23 61 90 ... 69 78 79]
[46 89 26 ... 58 87 30]
[11 12 86 ... 99 52 4]]
Final cluster:
[14 3 8 7 2 7 4 11 13 14 1 6 1 1 9 5 15 10 0 3 6 7 11 10
12 9 10 4 7 2 6 1 4 5 6 15 13 4 12 8 5 6 1 5 1 13 2 11
11 4 2 0 11 11 13 5 6 14 1 3 3 12 11 6 9 14 14 10 9 13 0 10
5 5 8 8 9 10 12 0 9 4 15 9 9 7 0 4 11 12 12 7 13 10 2 13
13 9 4 11 5 15 12 7 3 0 15 1 4 9 8 4 6 7 3 8 0 13 0 15
9 0 10 6 14 0 13 0 5 2 14 11 4 2 5 15 2 13 2 3 14 10 11 2
5 13 10 9 15 3 4 11 2 0 2 2 11 14 0 2 13 9 5 12 4 10 7 10
5 2 9 15 6 2 9 3 5 11 14 3 2 12 5 7 6 4 3 4 2 5 7 2
10 1 4 9 1 4 6 13 11 3 2 11 1 14 5 4 10 10 1 5 10 6 5 3
8 13 13 13 2 0 2 14 8 13 5 15 13 10 9 3 4 13 11 6 6 2 2 5
3 10 1 9 15 15 8 6 3 7 9 14 7 4 11 12 4 5 4 7 12 7 9 7
13 7 2 0 14 7 14 8 10 3 1 10 11 7 6 7 12 10 12 1 10 4 15 4
7 9 10 6 11 0 3 13 2 11 15 2 0 9 2 6 13 10 8 10 5 9 1 5
4 8 6 6 1 14 0 14]
The means:
[[60 64 39 46 39 59 67 57 64 62 20 38 50 46 35 41 56 41 67 34 50 40 44 34
62 74 60 64 62 62 72 48]
[64 47 20 55 71 43 56 42 44 45 45 30 32 60 60 58 32 45 54 58 58 62 64 78
49 46 48 73 62 63 32 50]
[37 55 46 45 63 55 33 54 65 27 70 43 55 59 53 56 45 41 42 29 42 50 46 35
54 48 48 75 38 51 32 37]
[65 37 42 26 34 36 31 45 51 58 67 74 40 41 31 54 33 61 37 41 61 58 53 54
69 73 33 38 69 48 55 35]
[19 26 45 68 51 50 32 61 52 67 64 56 59 41 42 51 77 56 56 51 49 62 37 39
34 44 61 50 37 61 60 51]
[43 41 66 34 57 41 53 33 28 49 58 41 75 48 49 20 46 31 50 49 46 56 61 32
44 46 54 38 62 29 58 41]
[31 47 60 31 48 52 57 35 51 57 31 70 19 36 42 48 46 42 61 56 24 53 38 50
44 50 41 61 36 67 45 47]
[26 60 52 32 56 57 49 61 48 35 44 55 43 61 26 55 71 49 43 56 48 23 61 64
36 34 23 32 56 67 64 48]
[67 73 60 50 57 32 67 34 45 57 37 52 51 23 46 59 59 62 64 64 77 61 65 34
62 21 26 47 68 30 17 68]
[60 34 22 53 34 41 55 67 50 54 55 44 62 57 48 45 43 75 50 36 44 46 75 54
42 47 58 64 52 31 46 45]
[66 49 58 62 43 65 43 59 55 44 66 40 61 35 62 39 50 65 50 47 36 42 52 36
53 38 30 52 61 61 63 72]
[48 48 49 57 59 46 51 31 66 50 55 30 51 65 53 44 43 52 46 63 58 37 43 76
66 45 71 22 54 52 49 55]
[52 31 69 58 64 70 56 54 36 48 38 26 24 35 37 54 58 41 29 46 62 52 54 53
35 44 58 37 21 28 58 62]
[41 66 50 41 42 54 80 29 51 44 42 66 55 61 42 46 38 71 35 56 50 44 45 39
52 68 68 43 41 41 28 73]
[54 41 69 69 32 42 33 62 25 45 31 47 44 47 69 64 55 28 50 53 59 67 43 64
61 78 51 35 64 73 45 46]
[44 62 24 72 25 37 40 43 56 71 55 53 53 34 55 71 58 38 48 48 65 18 63 55
56 58 55 20 47 57 41 49]]
Total running time of the script: ( 0 minutes 38.937 seconds)