牛客练习赛51 F ABCBA 可持久化线段树

【牛客练习赛51 F ABCBA 可持久化线段树】F ABCBA

解法:我们可以用可持久化线段树维护某点到根的所有信息,那么每次查询,我们找到 u v 的 lca,用线段树分别查询[lca, u],[lca, v]的区间并进行合并就是得到答案,问题转化为线段树维护子序列为ABCBA的数量,我们分别维护区间子序列A,AB,ABC,ABCB,ABCBA,B,BC,BCB,BCBA,C,CB,CBA,BA的数量,每次区间合并,用ls表示左儿子,rs表示右儿子,不难发现 ABCBA = ABCBA[ls] + ABCBA[rs] + A[ls] * BCBA[rs] + AB[ls] *CBA[rs] + ABC[ls] * BA[rs] + ABCB[ls] * A[rs],其他信息亦同理,那么这个题就变成码力题了,思维难度几乎为0
#include #define ll long long using namespace std; const int maxn = 3e4 + 10, mod = 10007; vector G[maxn]; char s[maxn]; void add(int &x, int y) { x += y; while (x >= mod) x -= mod; while (x < 0) x += mod; } struct node { int cat, A, AB, ABC, ABCB, B, BC, BCB, BCBA, C, CB, CBA, BA; node operator+(const node &t) const { node tmp; tmp.cat = (cat + t.cat + A * t.BCBA + AB * t.CBA + ABC * t.BA + ABCB * t.A) % mod; tmp.A = (A + t.A) % mod; tmp.AB = (AB + t.AB + A * t.B) % mod; tmp.ABC = (ABC + t.ABC + A * t.BC + AB * t.C) % mod; tmp.ABCB = (ABCB + t.ABCB + A * t.BCB + AB * t.CB + ABC * t.B) % mod; tmp.B = (B + t.B) % mod; tmp.BC = (BC + t.BC + B * t.C) % mod; tmp.BCB = (BCB + t.BCB + B * t.CB + BC * t.B) % mod; tmp.BCBA = (BCBA + t.BCBA + B * t.CBA + BC * t.BA + BCB * t.A) % mod; tmp.C = (C + t.C) % mod; tmp.CB = (CB + t.CB + C * t.B) % mod; tmp.CBA = (CBA + t.CBA + C * t.BA + CB * t.A) % mod; tmp.BA = (BA + t.BA + B * t.A) % mod; return tmp; } } tree[maxn * 20]; int rt[maxn], ls[maxn * 20], rs[maxn * 20], cnt, f[maxn][20], dep[maxn], n; #define mid (l + r) / 2 void up(int &o, int pre, int l, int r, int k, char c) { o = ++cnt; ls[o] = ls[pre]; rs[o] = rs[pre]; if (l == r) { if (c == 'A') tree[o].A = 1; else if (c == 'B') tree[o].B = 1; else if (c == 'C') tree[o].C = 1; return; } if (k <= mid) up(ls[o], ls[pre], l, mid, k, c); else up(rs[o], rs[pre], mid + 1, r, k, c); tree[o] = tree[ls[o]] + tree[rs[o]]; } void dfs(int u, int fa) { f[u][0] = fa; dep[u] = dep[fa] + 1; for (int i = 1; i < 18; i++) f[u][i] = f[f[u][i - 1]][i - 1]; up(rt[u], rt[fa], 1, n, dep[u], s[u]); for (auto v : G[u]) if (v != fa) dfs(v, u); } int LCA(int u, int v) { if (dep[u] < dep[v]) swap(u, v); for (int i = 17; ~i; i--) if (dep[f[u][i]] >= dep[v]) u = f[u][i]; if (u == v) return u; for (int i = 17; ~i; i--) if (f[u][i] != f[v][i]) u = f[u][i], v = f[v][i]; return f[u][0]; } node qu(int o, int l, int r, int ql, int qr) { if (l >= ql && r <= qr) return tree[o]; if (qr <= mid) return qu(ls[o], l, mid, ql, qr); else if (ql > mid) return qu(rs[o], mid + 1, r, ql, qr); else return qu(ls[o], l, mid, ql, qr) + qu(rs[o], mid + 1, r, ql, qr); } int main() { int m, u, v; scanf("%d%d", &n, &m); scanf("%s", s + 1); for (int i = 1; i < n; i++) { scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); } dfs(1, 0); while (m--) { scanf("%d%d", &u, &v); int lca = LCA(u, v); int ans = 0; if (u == lca || v == lca) { node tmp; if (u != lca) tmp = qu(rt[u], 1, n, dep[lca], dep[u]); else tmp = qu(rt[v], 1, n, dep[lca], dep[v]); add(ans, tmp.cat); } else { node t1 = qu(rt[u], 1, n, dep[lca], dep[u]); node t2 = qu(rt[v], 1, n, dep[lca] + 1, dep[v]); ans = (t1.cat + t2.cat + t1.A * t2.BCBA + t1.BA * t2.CBA + t1.CBA * t2.BA + t1.BCBA * t2.A) % mod; } printf("%d\n", ans); } }

    推荐阅读