1 minute read

What is Union Find

The union-find data structure, also called the disjoint-set data structure, stores a collection of disjoint (non-overlapping) sets.

It provides operations for adding new sets, merging sets (replacing them with their union), and finding a representative member of a set.

These operations make it possible to determine efficiently whether any two elements are in the same set.

Union Find Implementation

/**
 * A disjoint-set (union-find) structure over elements of type {@code T}.
 *
 * <p>Every tracked element belongs to exactly one set; sets can be merged, and two elements can be
 * tested for membership in the same set.
 *
 * @param <T> the element type
 */
public interface IUnionFind<T> {
    /**
     * Merges the set containing {@code o1} with the set containing {@code o2}.
     *
     * @param o1 an element of the first set
     * @param o2 an element of the second set
     */
    void merge(T o1, T o2);

    /**
     * Reports whether {@code o1} and {@code o2} currently belong to the same set.
     *
     * @param o1 the first element
     * @param o2 the second element
     * @return {@code true} if both elements are in the same set
     */
    boolean isSameSet(T o1, T o2);
}

/**
 * Hash-map based union-find (disjoint-set) implementation.
 *
 * <p>The universe of elements is fixed at construction: each element of the supplied list starts
 * in its own singleton set. Merges use union by size (the smaller tree is attached under the
 * larger root), and root lookups perform path compression.
 *
 * <p>Elements are compared with {@code equals}, so {@code T} needs consistent
 * {@code equals}/{@code hashCode}. {@code null} is accepted as an element if it was supplied to
 * the constructor. This class is not thread-safe: its state lives in unsynchronized
 * {@link HashMap}s.
 *
 * @param <T> the element type
 */
public class UnionFind2<T> implements IUnionFind<T> {

	/** Small demo: builds singleton sets over 1..5 and prints the results of a few merges and queries. */
	public static void main(String[] args) {
		Integer[] input = new Integer[] { 1, 2, 3, 4, 5 };

		UnionFind2<Integer> uf = new UnionFind2<>(Arrays.asList(input));

		System.out.println("1 2 same set?");
		System.out.println(uf.isSameSet(1, 2));

		System.out.println("merge 1, 2");
		uf.merge(1, 2);

		System.out.println("1 2 same set?");
		System.out.println(uf.isSameSet(1, 2));

		System.out.println("1 3 same set?");
		System.out.println(uf.isSameSet(1, 3));

		System.out.println("merge 1, 3");
		uf.merge(1, 3);

		System.out.println("merge 3, 4");
		uf.merge(3, 4);

		System.out.println("1 4 same set?");
		System.out.println(uf.isSameSet(1, 4));
	}

	/** Maps each element to its parent in the forest; a root maps to itself. */
	final Map<T, T> parentMap;
	/** Maps each current root to the number of elements in its set; non-roots have no entry. */
	final Map<T, Integer> sizeMap;

	/**
	 * Creates a structure in which every element of {@code list} is in its own singleton set.
	 * Equal (duplicate) elements collapse into a single element.
	 *
	 * @param list the elements to track; must not be {@code null}
	 */
	public UnionFind2(List<T> list) {
		parentMap = new HashMap<>();
		sizeMap = new HashMap<>();

		for (T ele : list) {
			parentMap.put(ele, ele);
			sizeMap.put(ele, 1);
		}
	}

	/**
	 * Merges the sets containing {@code o1} and {@code o2}; does nothing if they are already in
	 * the same set.
	 *
	 * @throws IllegalArgumentException if either element was not supplied to the constructor
	 */
	@Override
	public void merge(T o1, T o2) {
		T root1 = findParent(o1);
		T root2 = findParent(o2);
		if (sameElement(root1, root2)) {
			return;
		}

		int size1 = sizeMap.get(root1);
		int size2 = sizeMap.get(root2);
		// Union by size keeps trees shallow. On a tie, root1 is attached under root2.
		if (size1 > size2) {
			link(root2, root1, size1 + size2);
		} else {
			link(root1, root2, size1 + size2);
		}
	}

	/**
	 * Reports whether {@code o1} and {@code o2} are in the same set.
	 *
	 * @throws IllegalArgumentException if either element was not supplied to the constructor
	 */
	@Override
	public boolean isSameSet(T o1, T o2) {
		return sameElement(findParent(o1), findParent(o2));
	}

	/** Attaches root {@code child} directly under root {@code root} and records the merged size. */
	private void link(T child, T root, int mergedSize) {
		parentMap.put(child, root);
		sizeMap.put(root, mergedSize);
		// Only roots carry a size entry.
		sizeMap.remove(child);
	}

	/**
	 * Returns the root (representative) of the set containing {@code o}, compressing the path
	 * from {@code o} so that every visited element points directly at the root afterwards.
	 *
	 * @throws IllegalArgumentException if {@code o} was not supplied to the constructor
	 */
	private T findParent(T o) {
		// Without this check an unknown element would be mapped to a null parent, corrupting the
		// map and making any two unknown elements look like members of the same set.
		if (!parentMap.containsKey(o)) {
			throw new IllegalArgumentException("Unknown element: " + o);
		}

		// Pass 1: climb to the root, i.e. the element that is its own parent.
		T root = o;
		T next = parentMap.get(root);
		while (!sameElement(root, next)) {
			root = next;
			next = parentMap.get(root);
		}

		// Pass 2: re-point every element on the path directly at the root.
		T current = o;
		while (!sameElement(current, root)) {
			T parent = parentMap.get(current);
			parentMap.put(current, root);
			current = parent;
		}

		return root;
	}

	/** Null-safe element equality, consistent with how the backing hash maps match keys. */
	private static boolean sameElement(Object a, Object b) {
		return java.util.Objects.equals(a, b);
	}
}