class DisjointSet(object):
    """Union-find (disjoint-set forest) over the integers ``0 .. n-1``.

    Uses path compression in :meth:`find` and union by rank in :meth:`union`.

    Attributes:
        index: parent pointers; ``index[i] == i`` marks ``i`` as the root of
            its tree.
        depth: rank of each node. For roots it is an upper bound on the height
            of the tree; it is not updated for nodes that stop being roots.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets ``{0}, {1}, ..., {n-1}``."""
        self.index = list(range(n))
        self.depth = [0] * n

    def tree_num(self):
        """Return the number of disjoint sets (trees) currently in the forest."""
        return sum(1 for node, parent in enumerate(self.index) if node == parent)

    def find(self, i):
        """Return the root of the tree containing ``i``.

        Every node on the path from ``i`` to the root is re-pointed directly at
        the root (path compression). Implemented iteratively so that very deep
        trees cannot exceed Python's recursion limit.
        """
        # First pass: walk parent pointers up to the root.
        root = i
        while self.index[root] != root:
            root = self.index[root]
        # Second pass: point every node on the path straight at the root.
        while i != root:
            parent = self.index[i]
            self.index[i] = root
            i = parent
        return root

    def union(self, i, j):
        """Merge the sets containing ``i`` and ``j`` (no-op if already joined).

        The root of lower rank is attached under the root of higher rank. On a
        tie, ``i``'s root is attached under ``j``'s root and that root's rank
        is incremented.
        """
        index_i = self.find(i)
        index_j = self.find(j)
        if index_i == index_j:
            return
        # Ranks must be compared at the roots, not at the original arguments:
        # a non-root node's depth entry says nothing about its tree's height.
        if self.depth[index_i] > self.depth[index_j]:
            self.index[index_j] = index_i
        elif self.depth[index_i] < self.depth[index_j]:
            self.index[index_i] = index_j
        else:
            self.index[index_i] = index_j
            self.depth[index_j] += 1

    def __str__(self):
        return 'items: {}\nindex: {}\ndepth: {}'.format(
            list(range(len(self.index))), self.index, self.depth)