HeteroCL Tutorial: K-means Clustering Algorithm

Author: Yi-Hsiang Lai (seanlatias@github), Ziyan Feng

This tutorial implements the K-means clustering algorithm in HeteroCL.

import numpy as np
import heterocl as hcl
import time
import random

Define the number of cluster means as K, the number of points as N, the number of dimensions per point as dim, and the number of iterations as niter.

# Problem size used by the kernel and by the driver code.
K: int = 16       # number of cluster means
N: int = 320      # number of points
dim: int = 32     # number of coordinates per point
niter: int = 200  # number of assignment/update iterations

# Initialize HeteroCL (no arguments, i.e. its default configuration)
# before any placeholders or kernels are created.
hcl.init()

Main Algorithm

def top(target=None):
    """Build the HeteroCL K-means kernel.

    The built function is called as ``f(points, means, labels)``. It runs
    ``niter`` iterations of the assignment and mean-update steps. It writes
    each point's cluster index into ``labels`` and updates ``means`` in
    place.

    Args:
        target: Build target, forwarded unchanged to ``hcl.build``.

    Returns:
        Whatever ``hcl.build`` returns for the scheduled kernel.
    """
    points = hcl.placeholder((N, dim))
    means = hcl.placeholder((K, dim))

    def kmeans(points, means):
        def loop_kernel(labels):
            # Assignment step: label each point with its nearest mean by
            # squared Euclidean distance. Strict '<' means ties go to the
            # lowest k.
            with hcl.for_(0, N, name="N") as n:
                # Start from the largest 32-bit signed value so the first
                # mean is always accepted. A small sentinel such as 100000
                # can be below the largest possible distance (dim * 99**2
                # for coordinates in [0, 99]), which would leave labels[n]
                # stale. Assumes the default 32-bit integer dtype from
                # hcl.init().
                min_dist = hcl.scalar((1 << 31) - 1)
                with hcl.for_(0, K) as k:
                    dist = hcl.scalar(0)
                    with hcl.for_(0, dim) as d:
                        dist_ = points[n, d] - means[k, d]
                        dist.v += dist_ * dist_
                    with hcl.if_(dist.v < min_dist.v):
                        min_dist.v = dist.v
                        labels[n] = k

            # Update step: accumulate per-cluster point counts and
            # coordinate sums, then replace each mean by the integer average.
            num_k = hcl.compute((K,), lambda x: 0)
            sum_k = hcl.compute((K, dim), lambda x, y: 0)

            def calc_sum(n):
                num_k[labels[n]] += 1
                with hcl.for_(0, dim) as d:
                    sum_k[labels[n], d] += points[n, d]

            hcl.mutate((N,), lambda n: calc_sum(n), "calc_sum")

            # An empty cluster (num_k[k] == 0) keeps its previous mean
            # instead of dividing by zero. The divisor itself is clamped to
            # 1 because select may evaluate both of its operands.
            hcl.update(
                means,
                lambda k, d: hcl.select(
                    num_k[k] > 0,
                    sum_k[k, d] // hcl.select(num_k[k] > 0, num_k[k], 1),
                    means[k, d]),
                "update_mean")

        labels = hcl.compute((N,), lambda x: 0)
        hcl.mutate((niter,), lambda _: loop_kernel(labels), "main_loop")
        return labels

    # Create the schedule and apply compute customizations.
    s = hcl.create_schedule([points, means], kmeans)
    main_loop = kmeans.main_loop
    update_mean = main_loop.update_mean
    # Pipeline the per-point assignment loop (named "N" above).
    s[main_loop].pipeline(main_loop.N)
    # Fully unroll the accumulation over points.
    s[main_loop.calc_sum].unroll(main_loop.calc_sum.axis[0])
    # Fuse the (k, d) mean-update loops into one, then unroll it.
    fused = s[update_mean].fuse(update_mean.axis[0], update_mean.axis[1])
    s[update_mean].unroll(fused)
    return hcl.build(s, target=target)

f = top()

# Random integer data set. The initial means are the points at K distinct
# indices; fancy indexing copies, so means_np does not alias points_np.
points_np = np.random.randint(100, size=(N, dim))
labels_np = np.zeros(N)
means_np = points_np[random.sample(range(N), K), :]

hcl_points = hcl.asarray(points_np, dtype=hcl.Int())
hcl_means = hcl.asarray(means_np, dtype=hcl.Int())
hcl_labels = hcl.asarray(labels_np)

# time.perf_counter() is a monotonic, high-resolution clock intended for
# measuring elapsed time; time.time() follows the wall clock and can jump.
start = time.perf_counter()
f(hcl_points, hcl_means, hcl_labels)
total_time = time.perf_counter() - start
print("Kernel time (s): {:.2f}".format(total_time))

print("All points:")
print(hcl_points)
print("Final cluster:")
print(hcl_labels)
print("The means:")
print(hcl_means)

# Cross-check against the reference implementation. The golden model gets
# the points with a label column appended. NOTE(review): the assert below
# assumes kmeans_golden updates means_np in place - confirm in
# kmeans_golden.
from kmeans_golden import kmeans_golden
kmeans_golden(niter, K, N, dim, np.concatenate((points_np,
    np.expand_dims(labels_np, axis=1)), axis=1), means_np)
assert np.allclose(hcl_means.asnumpy(), means_np)

Out:

Kernel time (s): 0.04
All points:
[[99 85 64 ... 54 94 83]
 [43 26 32 ...  1 57 39]
 [15 93 93 ... 12 41 72]
 ...
 [23 61 90 ... 69 78 79]
 [46 89 26 ... 58 87 30]
 [11 12 86 ... 99 52  4]]
Final cluster:
[14  3  8  7  2  7  4 11 13 14  1  6  1  1  9  5 15 10  0  3  6  7 11 10
 12  9 10  4  7  2  6  1  4  5  6 15 13  4 12  8  5  6  1  5  1 13  2 11
 11  4  2  0 11 11 13  5  6 14  1  3  3 12 11  6  9 14 14 10  9 13  0 10
  5  5  8  8  9 10 12  0  9  4 15  9  9  7  0  4 11 12 12  7 13 10  2 13
 13  9  4 11  5 15 12  7  3  0 15  1  4  9  8  4  6  7  3  8  0 13  0 15
  9  0 10  6 14  0 13  0  5  2 14 11  4  2  5 15  2 13  2  3 14 10 11  2
  5 13 10  9 15  3  4 11  2  0  2  2 11 14  0  2 13  9  5 12  4 10  7 10
  5  2  9 15  6  2  9  3  5 11 14  3  2 12  5  7  6  4  3  4  2  5  7  2
 10  1  4  9  1  4  6 13 11  3  2 11  1 14  5  4 10 10  1  5 10  6  5  3
  8 13 13 13  2  0  2 14  8 13  5 15 13 10  9  3  4 13 11  6  6  2  2  5
  3 10  1  9 15 15  8  6  3  7  9 14  7  4 11 12  4  5  4  7 12  7  9  7
 13  7  2  0 14  7 14  8 10  3  1 10 11  7  6  7 12 10 12  1 10  4 15  4
  7  9 10  6 11  0  3 13  2 11 15  2  0  9  2  6 13 10  8 10  5  9  1  5
  4  8  6  6  1 14  0 14]
The means:
[[60 64 39 46 39 59 67 57 64 62 20 38 50 46 35 41 56 41 67 34 50 40 44 34
  62 74 60 64 62 62 72 48]
 [64 47 20 55 71 43 56 42 44 45 45 30 32 60 60 58 32 45 54 58 58 62 64 78
  49 46 48 73 62 63 32 50]
 [37 55 46 45 63 55 33 54 65 27 70 43 55 59 53 56 45 41 42 29 42 50 46 35
  54 48 48 75 38 51 32 37]
 [65 37 42 26 34 36 31 45 51 58 67 74 40 41 31 54 33 61 37 41 61 58 53 54
  69 73 33 38 69 48 55 35]
 [19 26 45 68 51 50 32 61 52 67 64 56 59 41 42 51 77 56 56 51 49 62 37 39
  34 44 61 50 37 61 60 51]
 [43 41 66 34 57 41 53 33 28 49 58 41 75 48 49 20 46 31 50 49 46 56 61 32
  44 46 54 38 62 29 58 41]
 [31 47 60 31 48 52 57 35 51 57 31 70 19 36 42 48 46 42 61 56 24 53 38 50
  44 50 41 61 36 67 45 47]
 [26 60 52 32 56 57 49 61 48 35 44 55 43 61 26 55 71 49 43 56 48 23 61 64
  36 34 23 32 56 67 64 48]
 [67 73 60 50 57 32 67 34 45 57 37 52 51 23 46 59 59 62 64 64 77 61 65 34
  62 21 26 47 68 30 17 68]
 [60 34 22 53 34 41 55 67 50 54 55 44 62 57 48 45 43 75 50 36 44 46 75 54
  42 47 58 64 52 31 46 45]
 [66 49 58 62 43 65 43 59 55 44 66 40 61 35 62 39 50 65 50 47 36 42 52 36
  53 38 30 52 61 61 63 72]
 [48 48 49 57 59 46 51 31 66 50 55 30 51 65 53 44 43 52 46 63 58 37 43 76
  66 45 71 22 54 52 49 55]
 [52 31 69 58 64 70 56 54 36 48 38 26 24 35 37 54 58 41 29 46 62 52 54 53
  35 44 58 37 21 28 58 62]
 [41 66 50 41 42 54 80 29 51 44 42 66 55 61 42 46 38 71 35 56 50 44 45 39
  52 68 68 43 41 41 28 73]
 [54 41 69 69 32 42 33 62 25 45 31 47 44 47 69 64 55 28 50 53 59 67 43 64
  61 78 51 35 64 73 45 46]
 [44 62 24 72 25 37 40 43 56 71 55 53 53 34 55 71 58 38 48 48 65 18 63 55
  56 58 55 20 47 57 41 49]]

Total running time of the script: ( 0 minutes 38.937 seconds)

Gallery generated by Sphinx-Gallery