树形动规|[bzoj4543/3522]Hotel

题目大意 一颗n个节点的树。
找三个不同编号的节点,使它们两两间距离相同(一条边距离视作1),求方案数。
在3522的版本中,n<=5000
在4543的版本中,n<=100000
3522 我们来考虑DP
用f[i,j]表示以i为根的子树里与i距离为j的点的个数。
g[i,j]表示在以i为根的子树里,有多少对(x,y)满足x与y到它们lca的距离均为d,且i到它们的lca距离为d-j(容易看出第三个不在i子树内的与i距离为j的点能与这些点匹配成合法解)
接下来用x表示当前节点,y表示一个子节点。f与g都是实时更新的,表示当前做掉的儿子的信息,然后加入一个新的儿子的信息。
一开始先考虑从某一个儿子转移过来,然后此时统计有至少一个点在该儿子子树内时的答案,那么一定是有两个点在该儿子子树内,第三个点就是x。
此时答案就是g[x][0]。
然后接下来枚举其他儿子。
转移式是:
1、f[x,i]+=f[y,i-1]
2、g[x,i-1]+=g[y,i]
3、g[x,i+1]+=f[x,i+1]*f[y,i]
都比较好理解
每做完一个儿子,还要统计答案,就是三个点至少有一个但不是全部在这个新儿子子树里的答案个数。
ans+=f[x,i-1]*g[y,i]+g[x,i+1]*f[y,i]
分别是一个在儿子子树内和两个在儿子子树内的答案。
这样我们是n^2的。
空间当然也是n^2的……
其实想做n^2也可以完全不用这么做,可以通过一些dp得到一个点子树里距离它为i的点对数,然后再dp出不在一个点子树里与它距离为i的点个数,于是就可以算了。
接下来4543的算法在3522的算法上进行改进。
4543 【树形动规|[bzoj4543/3522]Hotel】我们考虑……那个叫长链剖分吗?
就是选择一个深度最大的儿子当重儿子,一个点与重儿子间连的边叫重边,非重儿子称为轻儿子,非重边称为轻边。
从3522的算法可以看出,如果我们用指针来实现,一开始对一个儿子的信息进行位移(f[x][i]=f[y][i-1]和g[x][i]=g[y][i+1]相当于位移一位)可以O(1)实现!
我们不妨把选择的这个儿子就钦点为重儿子,那么接下来只需要对轻边做转移。
设mx[i]表示i子树内深度最大点。
那么一条轻边x到y转移的复杂度是O(dep[mx[y]]-dep[x])
深度最大点肯定是个叶子,不难看出,dep[mx[y]]-dep[x]正好是mx[y]所在重链的长度!
因此转移的总复杂度就是重链长度和,为n。
那么这个算法是线性的!
至于空间分配,当然也只需要o(n)的空间。
给每条重链分配正比于重链长度的空间即可。
具体指针分配空间可看代码实现。
因为我之前也不会这玩意所以我代码基本就是抄的,感谢neither_nor大爷的代码

#include #include #define fo(i,a,b) for(i=a; i<=b; i++) using namespace std; typedef long long ll; const int maxn=100000+10; int h[maxn],go[maxn*2],next[maxn*2]; ll xdl[maxn*5]; ll *f[maxn],*g[maxn]; int mx[maxn],dep[maxn]; ll *gjx=xdl+5; int i,j,k,l,t,n,m,tot; ll ans; int read(){ int x=0,f=1; char ch=getchar(); while (ch<'0'||ch>'9'){ if (ch=='-') f=-1; ch=getchar(); } while (ch>='0'&&ch<='9'){ x=x*10+ch-'0'; ch=getchar(); } return x*f; } void add(int x,int y){ go[++tot]=y; next[tot]=h[x]; h[x]=tot; } void dfs(int x,int F){ int i,t,y; mx[x]=x; t=h[x]; while (t){ y=go[t]; if (y!=F){ dep[y]=dep[x]+1; dfs(y,x); if (dep[mx[y]]>dep[mx[x]]) mx[x]=mx[y]; } t=next[t]; } t=h[x]; while (t){ y=go[t]; if (y!=F&&(mx[x]!=mx[y]||x==1)){ gjx+=dep[mx[y]]-dep[x]+1; f[mx[y]]=gjx; g[mx[y]]=(gjx+=1); gjx+=(dep[mx[y]]-dep[x])*2+1; } t=next[t]; } } void dp(int x,int F){ int i,j,t,y; t=h[x]; while (t){ y=go[t]; if (y!=F){ dp(y,x); if (mx[y]==mx[x]){ f[x]=f[y]-1; g[x]=g[y]+1; } } t=next[t]; } f[x][0]=1; ans+=g[x][0]; t=h[x]; while (t){ y=go[t]; if (y!=F&&mx[x]!=mx[y]){ fo(j,0,dep[mx[y]]-dep[x]) ans+=f[x][j-1]*g[y][j]+g[x][j+1]*f[y][j]; fo(j,0,dep[mx[y]]-dep[x]){ f[x][j]+=f[y][j-1]; g[x][j-1]+=g[y][j]; g[x][j+1]+=f[x][j+1]*f[y][j]; } } t=next[t]; } } int main(){ n=read(); fo(i,1,n-1){ j=read(); k=read(); add(j,k); add(k,j); } dep[1]=1; dfs(1,0); dp(1,0); printf("%lld\n",ans); }

    推荐阅读