Minimum Spanning Tree (MST)

Proving the correctness of a greedy algorithm is usually a challenge. And of course it's even more challenging to find one in the first place - just look at how these are named.

But it's relatively easy to code them up during a tech interview. And usually no one asks to discuss the proof of correctness.

Tree vs. Graph

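A tree is a connected graph without cycles. A spanning tree of a graph is a tree that contains all of the graph's vertices, so it has exactly V-1 edges.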

Prim - builds vertices

  1. Start with a single vertex.
  2. Grow the tree by one edge. Find the minimum-weight edge that connects the tree to an outside vertex.
  3. Repeat step 2 until all vertices are in the tree.
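
The steps above can be sketched roughly as follows. This is a minimal illustration, not a definitive implementation: it assumes a connected, undirected graph whose vertices are labelled 0..n-1 and whose edges are given as {u, v, weight} triples. The names PrimSketch and mstWeight are made up for this example, and a PriorityQueue is one common way (not the only one) to find the minimum-weight edge leaving the tree.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical helper class, for illustration only.
class PrimSketch {

    // Assumes a connected, undirected graph with vertices 0..n-1.
    // Each edge is {u, v, weight}. Returns the total weight of an MST.
    static int mstWeight(int n, int[][] edges) {
        // adjacency list: adj.get(u) holds {neighbor, weight} pairs
        List<List<int[]>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            adj.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            adj.get(e[0]).add(new int[] {e[1], e[2]});
            adj.get(e[1]).add(new int[] {e[0], e[2]});
        }

        boolean[] inTree = new boolean[n];

        // candidate edges leaving the tree, cheapest first: {vertex, weight}
        PriorityQueue<int[]> pq =
                new PriorityQueue<>((x, y) -> Integer.compare(x[1], y[1]));

        // step 1: start with a single vertex (vertex 0, reached at cost 0)
        pq.add(new int[] {0, 0});

        int total = 0;
        int added = 0;

        // steps 2 and 3: keep taking the cheapest edge to an outside vertex
        while (!pq.isEmpty() && added < n) {
            int[] cur = pq.poll();
            int v = cur[0];
            if (inTree[v]) {
                continue;  // stale entry: v joined the tree via a cheaper edge
            }
            inTree[v] = true;
            total += cur[1];
            added++;

            // every edge from v to an outside vertex becomes a candidate
            for (int[] next : adj.get(v)) {
                if (!inTree[next[0]]) {
                    pq.add(next);
                }
            }
        }
        return total;
    }
}
```

Calling PrimSketch.mstWeight(n, edges) with the vertex count and an edge list returns the summed weight of the chosen edges; collecting the edges themselves would need each queue entry to also remember which tree vertex it came from.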

Kruskal - builds edges

  1. Sort all the edges in ascending order of weight.
  2. Pick the smallest remaining edge. Check whether it forms a cycle with the edges chosen so far. If no cycle is formed, include this edge. Otherwise, discard it.
  3. Repeat step 2 until (V-1) edges have been chosen.
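
As a rough sketch of these steps, assuming vertices labelled 0..n-1 and edges given as {u, v, weight} triples: the names KruskalSketch and mstWeight are invented for this example. The cycle check uses a bare-bones union-find over an int array where a root points to itself, which is a simplification and differs from the 0-as-root convention used in the full implementation later in this post.

```java
import java.util.Arrays;

// Hypothetical helper class, for illustration only.
class KruskalSketch {

    // Follow parent pointers up to the representative of x's set.
    // Convention assumed here: parent[i] == i means i is a representative.
    private static int find(int[] parent, int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];  // shorten the path as we go
            x = parent[x];
        }
        return x;
    }

    // Assumes a connected, undirected graph with vertices 0..n-1.
    // Each edge is {u, v, weight}. Returns the total weight of an MST.
    static int mstWeight(int n, int[][] edges) {
        // step 1: sort edges by ascending weight (on a copy of the outer array)
        int[][] sorted = edges.clone();
        Arrays.sort(sorted, (a, b) -> Integer.compare(a[2], b[2]));

        // initially every vertex is its own set
        int[] parent = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }

        int total = 0;
        int chosen = 0;

        for (int[] e : sorted) {
            // step 3: stop once V-1 edges have been chosen
            if (chosen == n - 1) {
                break;
            }
            int rootU = find(parent, e[0]);
            int rootV = find(parent, e[1]);

            // step 2: same representative means the edge would close a cycle
            if (rootU != rootV) {
                parent[rootU] = rootV;  // merge the two sets
                total += e[2];
                chosen++;
            }
        }
        return total;
    }
}
```

On the same input, this should return the same total weight as a Prim-style approach, even if the two pick different edges when weights tie.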

Cycles - Disjoint Set / Union-Find

Step 2 of Kruskal's algorithm needs to detect cycles. Union-Find is an efficient data structure for this scenario, compared with running a normal DFS with a visited set for every edge.

Key Ideas

  1. Each set is represented by a single element in the set. Every other element is made to point, directly or through other elements, to that representative.
  2. Given any element, we can find its set by following its pointers up to the representative.
  3. To union two sets, just make one representative point to the other.
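
The three ideas can be illustrated with a deliberately naive sketch that has no optimizations. The class name NaiveUnionFind, the use of plain int ids 0..n-1, and the convention that a representative points to itself are all assumptions made for this example.

```java
// Hypothetical minimal union-find, for illustration only (no optimizations).
class NaiveUnionFind {

    // index: element id, value: the element it points to
    private final int[] parent;

    NaiveUnionFind(int n) {
        parent = new int[n];
        // idea 1: each set has one representative; initially every
        // element is alone, so it represents itself
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }
    }

    // idea 2: follow the pointers until reaching the representative
    int find(int x) {
        while (parent[x] != x) {
            x = parent[x];
        }
        return x;
    }

    // idea 3: point one representative at the other.
    // Returns false if a and b were already in the same set, which is
    // exactly the "this edge would form a cycle" signal Kruskal needs.
    boolean union(int a, int b) {
        int rootA = find(a);
        int rootB = find(b);
        if (rootA == rootB) {
            return false;
        }
        parent[rootA] = rootB;
        return true;
    }
}
```

Without path compression or union by size, the pointer chains here can grow as long as the number of elements, which is what the implementation below is designed to avoid.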

Key Implementation

/*
 * Goal: given any element, find which set it belongs to.
 * Strategy:
 *  1. keep following parent pointers until reaching the root (head).
 *  2. on the way back, make every touched node point directly to the root
 *     (path compression). Combined with union by size, this makes later
 *     look-ups nearly O(1) amortized.
 * Recursive definition:
 *  input: current node
 *  output: its head
 * Convention: parent[id] == 0 means id is a head; valid ids start at 1.
 */
private Integer findHead(Integer id) {
    //this is the head
    if (parent[id] == 0) {
        return id;
    } else {  //trace back to head
        //by recursive definition, findHead(parent[id]) returns parent's head
        //now setting it to current node's head
        parent[id] = findHead(parent[id]);
        return parent[id];
    }
}

Full Implementation

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

//Tree implementation of union find data structure
public class DisjointSet<T> {

    //next id of each element
    private Integer idGen = 0;

    //index: id, value: parent id
    private int[] parent;

    //size of each group
    private int[] sizeOf;

    //key: object in set, value: id of object
    private Map<T, Integer> map = new HashMap<>();

    //make set
    public DisjointSet (Set<T> set){

        if (set == null) {
            throw new IllegalArgumentException("set cannot be null.");
        }

        //create parent array
        int size = set.size();
        parent = new int[size + 1];  //use 0 as pointing to self

        //index: group id, value: size of this group
        sizeOf = new int[size + 1];

        //initially, each element is a size 1 group
        for (int i = 1; i < sizeOf.length; i++) {
            sizeOf[i] = 1;
        }

        for (T element : set) {
            map.put(element, ++idGen);  //first id will be 1
        }
    }

    /**
     * union operation with two elements
     * union the set which a belongs to, with the set which b belongs to
     * @param a object to union with
     * @param b object to union with
     */
    public void unionElement(T a, T b) {
        int groupA = find(a);
        int groupB = find(b);

        union(groupA, groupB);
    }

    /*
     * union operation with two sets
     * @param a set representative id
     * @param b set representative id
     */
    private void union(int a, int b) {

        //same group, no operation
        if (a == b) {
            return;
        }

        //make the smaller set point to the bigger set
        if (sizeOf[a] < sizeOf[b]) {
            parent[a] = b;

            //update new size
            sizeOf[b] = sizeOf[a] + sizeOf[b];
        } else {
            parent[b] = a;

            //update new size
            sizeOf[a] = sizeOf[a] + sizeOf[b];
        }
    }

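    /**
     * find operation
     * @param target Object to find
     * @return the group id of target
     * @throws IllegalArgumentException if target was not in the initial set
     */
    public Integer find(T target) {

        //map.get returns null for unknown objects; fail with a clear message
        //instead of a NullPointerException on unboxing
        Integer id = map.get(target);
        if (id == null) {
            throw new IllegalArgumentException("target is not in any set.");
        }
        return findHead(id);

    }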

    //find head of this id
    private Integer findHead(Integer id) {
        //this is the head
        if (parent[id] == 0) {
            return id;
        } else {  //trace back to head
            parent[id] = findHead(parent[id]);
            return parent[id];
        }
    }
}