.. rst-class:: sphx-glr-example-title

.. _sphx_glr_samples_kmeans_kmeans_main.py:


HeteroCL Tutorial: K-means Clustering Algorithm
================================================

**Author**: Yi-Hsiang Lai (seanlatias@github), Ziyan Feng

This is the K-means clustering algorithm written in HeteroCL.

.. code-block:: default


    import numpy as np
    import heterocl as hcl
    import time
    import random


Define the number of cluster means as K, the number of points as N, the
number of dimensions as dim, and the number of iterations as niter.

.. code-block:: default


    K = 16
    N = 320
    dim = 32
    niter = 200

    hcl.init()


Main Algorithm
==============

.. code-block:: default


    def top(target=None):
        points = hcl.placeholder((N, dim))
        means = hcl.placeholder((K, dim))

        def kmeans(points, means):
            def loop_kernel(labels):
                # assign cluster
                with hcl.for_(0, N, name="N") as n:
                    min_dist = hcl.scalar(100000)
                    with hcl.for_(0, K) as k:
                        dist = hcl.scalar(0)
                        with hcl.for_(0, dim) as d:
                            dist_ = points[n, d] - means[k, d]
                            dist.v += dist_ * dist_
                        with hcl.if_(dist.v < min_dist.v):
                            min_dist.v = dist.v
                            labels[n] = k
                # update mean
                num_k = hcl.compute((K,), lambda x: 0)
                sum_k = hcl.compute((K, dim), lambda x, y: 0)

                def calc_sum(n):
                    num_k[labels[n]] += 1
                    with hcl.for_(0, dim) as d:
                        sum_k[labels[n], d] += points[n, d]

                hcl.mutate((N,), lambda n: calc_sum(n), "calc_sum")
                hcl.update(means,
                           lambda k, d: sum_k[k, d] // num_k[k], "update_mean")

            labels = hcl.compute((N,), lambda x: 0)
            hcl.mutate((niter,), lambda _: loop_kernel(labels), "main_loop")
            return labels

        # create schedule and apply compute customization
        s = hcl.create_schedule([points, means], kmeans)
        main_loop = kmeans.main_loop
        update_mean = main_loop.update_mean
        s[main_loop].pipeline(main_loop.N)
        s[main_loop.calc_sum].unroll(main_loop.calc_sum.axis[0])
        fused = s[update_mean].fuse(update_mean.axis[0], update_mean.axis[1])
        s[update_mean].unroll(fused)
        return hcl.build(s, target=target)


    f = top()

    points_np = np.random.randint(100, size=(N, dim))
    labels_np = np.zeros(N)
    means_np = points_np[random.sample(range(N), K), :]

    hcl_points = hcl.asarray(points_np, dtype=hcl.Int())
    hcl_means = hcl.asarray(means_np, dtype=hcl.Int())
    hcl_labels = hcl.asarray(labels_np)

    start = time.time()
    f(hcl_points, hcl_means, hcl_labels)
    total_time = time.time() - start
    print("Kernel time (s): {:.2f}".format(total_time))

    print("All points:")
    print(hcl_points)
    print("Final cluster:")
    print(hcl_labels)
    print("The means:")
    print(hcl_means)

    from kmeans_golden import kmeans_golden
    kmeans_golden(niter, K, N, dim,
                  np.concatenate((points_np,
                                  np.expand_dims(labels_np, axis=1)), axis=1),
                  means_np)
    assert np.allclose(hcl_means.asnumpy(), means_np)
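Before looking at the generated output, it may help to see the same computation
written without the HeteroCL constructs. The sketch below performs one iteration
of what ``loop_kernel`` does: assign every point to its nearest mean, then
recompute each mean with integer division (mirroring the kernel's ``//``). The
helper name ``kmeans_step_np`` is invented here for illustration; it ignores the
kernel's initial ``100000`` distance bound and simply keeps the previous mean for
an empty cluster, so it is not a drop-in replacement for the built kernel.

.. code-block:: default


    import numpy as np

    def kmeans_step_np(points, means):
        """One K-means iteration in plain NumPy (illustrative sketch only)."""
        # squared Euclidean distance from every point to every mean: shape (N, K)
        dists = ((points[:, None, :] - means[None, :, :]) ** 2).sum(axis=2)
        # nearest mean for every point (first minimum wins, like the kernel's <)
        labels = dists.argmin(axis=1)
        new_means = means.copy()
        for k in range(means.shape[0]):
            members = points[labels == k]
            if len(members) > 0:
                # integer division, matching the kernel's sum_k // num_k
                new_means[k] = members.sum(axis=0) // len(members)
        return labels, new_means


The output of the actual HeteroCL run is shown below.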
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Kernel time (s): 0.04
    All points:
    [[99 85 64 ... 54 94 83]
     [43 26 32 ...  1 57 39]
     [15 93 93 ... 12 41 72]
     ...
     [23 61 90 ... 69 78 79]
     [46 89 26 ... 58 87 30]
     [11 12 86 ... 99 52  4]]
    Final cluster:
    [14 3 8 7 2 7 4 11 13 14 1 6 1 1 9 5 15 10 0 3 6 7 11 10 12 9 10 4 7 2 6 1
     4 5 6 15 13 4 12 8 5 6 1 5 1 13 2 11 11 4 2 0 11 11 13 5 6 14 1 3 3 12 11 6
     9 14 14 10 9 13 0 10 5 5 8 8 9 10 12 0 9 4 15 9 9 7 0 4 11 12 12 7 13 10 2 13
     13 9 4 11 5 15 12 7 3 0 15 1 4 9 8 4 6 7 3 8 0 13 0 15 9 0 10 6 14 0 13 0
     5 2 14 11 4 2 5 15 2 13 2 3 14 10 11 2 5 13 10 9 15 3 4 11 2 0 2 2 11 14 0 2
     13 9 5 12 4 10 7 10 5 2 9 15 6 2 9 3 5 11 14 3 2 12 5 7 6 4 3 4 2 5 7 2
     10 1 4 9 1 4 6 13 11 3 2 11 1 14 5 4 10 10 1 5 10 6 5 3 8 13 13 13 2 0 2 14
     8 13 5 15 13 10 9 3 4 13 11 6 6 2 2 5 3 10 1 9 15 15 8 6 3 7 9 14 7 4 11 12
     4 5 4 7 12 7 9 7 13 7 2 0 14 7 14 8 10 3 1 10 11 7 6 7 12 10 12 1 10 4 15 4
     7 9 10 6 11 0 3 13 2 11 15 2 0 9 2 6 13 10 8 10 5 9 1 5 4 8 6 6 1 14 0 14]
    The means:
    [[60 64 39 46 39 59 67 57 64 62 20 38 50 46 35 41 56 41 67 34 50 40 44 34 62 74 60 64 62 62 72 48]
     [64 47 20 55 71 43 56 42 44 45 45 30 32 60 60 58 32 45 54 58 58 62 64 78 49 46 48 73 62 63 32 50]
     [37 55 46 45 63 55 33 54 65 27 70 43 55 59 53 56 45 41 42 29 42 50 46 35 54 48 48 75 38 51 32 37]
     [65 37 42 26 34 36 31 45 51 58 67 74 40 41 31 54 33 61 37 41 61 58 53 54 69 73 33 38 69 48 55 35]
     [19 26 45 68 51 50 32 61 52 67 64 56 59 41 42 51 77 56 56 51 49 62 37 39 34 44 61 50 37 61 60 51]
     [43 41 66 34 57 41 53 33 28 49 58 41 75 48 49 20 46 31 50 49 46 56 61 32 44 46 54 38 62 29 58 41]
     [31 47 60 31 48 52 57 35 51 57 31 70 19 36 42 48 46 42 61 56 24 53 38 50 44 50 41 61 36 67 45 47]
     [26 60 52 32 56 57 49 61 48 35 44 55 43 61 26 55 71 49 43 56 48 23 61 64 36 34 23 32 56 67 64 48]
     [67 73 60 50 57 32 67 34 45 57 37 52 51 23 46 59 59 62 64 64 77 61 65 34 62 21 26 47 68 30 17 68]
     [60 34 22 53 34 41 55 67 50 54 55 44 62 57 48 45 43 75 50 36 44 46 75 54 42 47 58 64 52 31 46 45]
     [66 49 58 62 43 65 43 59 55 44 66 40 61 35 62 39 50 65 50 47 36 42 52 36 53 38 30 52 61 61 63 72]
     [48 48 49 57 59 46 51 31 66 50 55 30 51 65 53 44 43 52 46 63 58 37 43 76 66 45 71 22 54 52 49 55]
     [52 31 69 58 64 70 56 54 36 48 38 26 24 35 37 54 58 41 29 46 62 52 54 53 35 44 58 37 21 28 58 62]
     [41 66 50 41 42 54 80 29 51 44 42 66 55 61 42 46 38 71 35 56 50 44 45 39 52 68 68 43 41 41 28 73]
     [54 41 69 69 32 42 33 62 25 45 31 47 44 47 69 64 55 28 50 53 59 67 43 64 61 78 51 35 64 73 45 46]
     [44 62 24 72 25 37 40 43 56 71 55 53 53 34 55 71 58 38 48 48 65 18 63 55 56 58 55 20 47 57 41 49]]


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 38.937 seconds)