算法学习|最近公共祖先之树上倍增求法

一、问题引入 最近公共祖先(LCA)是求有根树上两点的深度最低的祖先节点,如下图,点5和点2的最近公共祖先为点4,点5和点3的最近公共祖先为点1,点5和点1的最近公共祖先为点1。
算法学习|最近公共祖先之树上倍增求法
文章图片



二、朴素算法 知道LCA定义后考虑最暴力的求解方法,可以先让在下面的点一层一层往上爬,直到两点具有相同深度(高度),之后两点同时一层一层往上爬,直到两点相同,这个相同的点就是最近公共祖先,若有m次询问,该算法时间复杂度为O(nm),显然过于暴力了。
三、树上倍增算法 朴素算法中找祖先节点都是一层一层地找,这样实在太慢了,如果能一次多爬几层就好了。于是我们运用二进制+倍增的思想,以指数级的速度光速向上爬。
在真正向上爬之前需要预处理出几个很有用的数组:fa数组(fa[i][j]表示i结点向上2^j层的祖先节点),depth数组(depth[i]表示i结点深度),log2数组(以2为底i的对数,可用可不用,有些模板需要用到,另外也有现成的函数可以替代)。如果学习过st表可能会发现fa数组很眼熟(学习链接:st表总结),这里的fa数组其实就是改变了下st表中f数组的定义,让它和树结合起来罢了。
depth和log2的预处理就不讲了,主要说下fa数组如何得到。在dfs中结点遍历是从根节点依次向下的,若已知当前节点父节点的fa数组,那当前节点的fa数组就很好求了,用这个式子fa[now][i] = fa[fa[now][i-1]][i-1]从1开始递推就行了,另外要初始化fa[now][0]为now的父节点。由于根节点的fa数组显然是已知的,那么就可以借助dfs从根节点向下更新其它结点的fa数组。
预处理完所有数组后就剩最后一步了,就是指数级地向上爬了。假如有两个结点x和y,x深度大于y的深度,按照朴素算法应该先把x提升至y的深度,假设x和y的深度差为13,将13转化为二进制是1101,也就是8、4和1,令x = fa[x][3],这时x就向上爬了8层,然后令x = fa[x][2],x向上爬了4层,最后令x = fa[x][0],x向上爬了1层,这样就到达了和y相同的深度。然后就是x和y一起向上爬,在这之前要加个特判,如果x和y已经相等了直接返回,不然会出错。x和y一起向上的过程和之前的过程是类似的,不过现在不知道目标点和当前层的深度差,所以要试探性地向上爬,i从大到小循环,如果fa[x][i] != fa[y][i]那么x = fa[x][i], y = fa[y][i],这样最终可以到达lca下面的一层,x或y的父节点就是lca。
这个算法预处理时间复杂度为O(nlogn),单次查询时间复杂度为O(logn)。
【算法学习|最近公共祖先之树上倍增求法】以P3379 【模板】最近公共祖先(LCA)为例放一个模板。

#include #include #include #include #include #include using namespace std; int head[500005], cnt, n, m, s, dep[500005], fa[500005][21]; struct edge { int to, next; }e[1000005]; void add(int u, int v) { e[++cnt].to = v; e[cnt].next = head[u]; head[u] = cnt; }void dfs(int now, int pre) { dep[now] = dep[pre]+1; fa[now][0] = pre; for(int i = 1; i <= 20; i++)//超出范围的祖先都是0号结点 fa[now][i] = fa[fa[now][i-1]][i-1]; for(int i = head[now]; i; i = e[i].next) if(e[i].to != pre) dfs(e[i].to, now); }int lca(int x, int y) { if(dep[x] < dep[y]) swap(x, y); //先把x提升到和y同样的深度 for(int i = 20; i >= 0; i--) if(dep[fa[x][i]] >= dep[y]) x = fa[x][i]; //如果此时x和y已经是lca,接下来无法到达lca下一层,若不特判一定出错 if(x == y) return x; //然后x和y同时提升到lca的下一层 for(int i = 20; i >= 0; i--) if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i]; return fa[x][0]; }signed main() { cin >> n >> m >> s; for(int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); add(u, v), add(v, u); } dfs(s, 0); for(int i = 1; i <= m; i++) { int x, y; scanf("%d%d", &x, &y); printf("%d\n", lca(x, y)); } return 0; }


    推荐阅读