0%

总结 虚树

某些树形dp的数据范围过大,有多次询问,每次询问选定关键点(关键点总和在 1e5~1e6 的范围)

例题:P2495 [SDOI2011]消耗战

此时会有很多点不需要参与dp,只要另外建出一棵包含有用点的虚树来

有一种增量构造的方法,难写难记(几个月前写的还没调出来),所以我干脆只学本文这种方法

推荐 shadowice1984 的博客

思路

我们只要掌握了一棵树的dfs序,就可以模拟整个dfs过程

对于关键点序列 a,dfs 一遍,求出每个点的 dfs 序(入栈和出栈的)

1
2
3
4
5
6
7
void dfs(int u, int last) {
pu[u] = ++dfu;
for (int i = head[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != last) {
dfs(v, u);
po[u] = ++dfu;
}

将关键点按dfs序排序,相邻两个点取 LCA,加入 a 序列,再把树根 1 加入

注意开个 bool 数组去重

最后把整个序列 a 复制一份,全部取负,加入序列

此时序列 a 中全是虚树节点,正数代表入栈点,负数代表出栈点

将 a 序列按照 dfs 序排序,cmp 中要区分出入栈:

1
2
3
inline bool cmp(int x, int y) {
return (x > 0 ? pu[x] : po[-x]) < (y > 0 ? pu[y] : po[-y]);
}

现在整个虚树的 dfs 序已经知道了,如果是简单的树形dp甚至不需要建树,直接模拟 dfs 遍历:

开一个栈,遍历 a 序列,如果是正数即入栈,进行 dfs 递归前的操作

负数说明出栈,进行 dfs 递归完各个子树后的操作

如果要建树,“dfs”时对于每个父子关系建边就行

1
2
3
4
5
6
7
for (int i = 1, u, v; i <= tot; ++i) {
if (a[i] > 0) stk[++top] = a[i];
else {
v = stk[top--], u = stk[top];
Add_edge(u, v), Add_edge(v, u);
}
}

注意多次询问的一定要清空(在 a 序列上的点),不能用 \(O(n)\) 的 memset

code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
int tot, top, cnt_e;
int a[N<<1], stk[N];
bool vis[N];
inline void Add_edge(int u, int v) {/*...*/}
void dfs(int u, int last) {
pu[u] = ++dfu;
// 预处理LCA用的deep,fa...
for (int i = head[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != last) {
dfs(v, u);
po[u] = ++dfu;
}
inline int LCA(int x, int y) {/*...*/}
inline bool cmp(int x, int y) {
return (x > 0 ? pu[x] : po[-x]) < (y > 0 ? pu[y] : po[-y]);
}
inline void build_tree(int k) {
tot = top = cnt_e = 0;
for (int i = 1; i <= k; ++i)
a[++tot] = array[i], vis[a[tot]] = true;
sort(a+1, a+tot+1, cmp);
for (int i = 1, lca; i < k; ++i) {
lca = LCA(a[i], a[i+1]);
if (!vis[lca]) a[++tot] = lca, vis[lca] = true;
}
if (!vis[1]) a[++tot] = 1;
k = tot;
for (int i = 1; i <= k; ++i)
a[++tot] = -a[i];
sort(a+1, a+tot+1, cmp);
for (int i = 1, u, v; i <= tot; ++i) {
if (a[i] > 0) stk[++top] = a[i];
else {
v = stk[top--], u = stk[top];
Add_edge(u, v), Add_edge(v, u);
vis[v] = false;
}
}
}

Luogu P2495 code

顺便放上去

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
using namespace std;

template<typename T>inline void read(T &x) {}
template<typename TT> inline void print(TT x, char end = '\n') {}

typedef long long LL;

const int N = 25e4 + 9;

struct Edge {
int nxt, to, val;
};

int n, m, cnt_e, dfu;
int head[N], pu[N], po[N], fa[N][20], deep[N], mn[N], a[N<<1], stk[N];
LL f[N];
bool vis[N];
Edge e[N<<1];

inline void add_edge(int u, int v, int w) {
e[++cnt_e] = (Edge){head[u], v, w}, head[u] = cnt_e;
}

void dfs(int u, int last) {
pu[u] = ++dfu;
deep[u] = deep[last] + 1;
fa[u][0] = last;
for (int i = 1; 1 << i <= deep[u]; ++i)
fa[u][i] = fa[fa[u][i-1]][i-1];
for (int i = head[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != last) {
mn[v] = min(mn[u], e[i].val);
dfs(v, u);
}
po[u] = ++dfu;
}

int LCA(int x, int y) {
if (deep[x] < deep[y]) swap(x, y);
for (int i = 19; ~i; --i)
if (deep[fa[x][i]] >= deep[y])
x = fa[x][i];
for (int i = 19; ~i; --i)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return x == y ? x : fa[x][0];
}

inline bool cmp(int x, int y) {
return (x > 0 ? pu[x] : po[-x]) < (y > 0 ? pu[y] : po[-y]);
}

int main() {
read(n);
for (int i = 1, u, v, w; i < n; ++i)
read(u), read(v), read(w), add_edge(u, v, w), add_edge(v, u, w);
mn[1] = 0x7fffffff;
dfs(1, 0);
read(m);
int k, tot, top;
while (m--) {
tot = top = 0;
read(k);
for (int i = 1; i <= k; ++i)
read(a[++tot]), vis[a[tot]] = true, f[a[tot]] = mn[a[tot]];
sort(a+1, a+k+1, cmp);
for (int i = 1, lca; i < k; ++i){
lca = LCA(a[i], a[i+1]);
if (!vis[lca]) a[++tot] = lca, vis[lca] = true;
}
if (!vis[1]) a[++tot] = 1;
k = tot;
for (int i = 1; i <= k; ++i)
a[++tot] = -a[i];
sort(a+1, a+tot+1, cmp);
for (int i = 1, u; i <= tot; ++i) {
if (a[i] > 0) stk[++top] = a[i];
else {
u = stk[top--];
if (u != 1)
f[stk[top]] += min(f[u], (LL)mn[u]);
else print(f[u]);
f[u] = vis[u] = 0;
}
}
}
return 0;
}