0%

题解 P3233 [HNOI2014]世界树

题目

给一棵树,每条边距离为1,q 次询问,每次选择 k 个关键点,树上每个点由距离最近的关键的管辖(距离相同选择编号最小的),求每个关键的管辖点数

\(N,q\le3\times10^5,\sum_{i=1}^q{k_i}\le3\times10^5\)

思路

树形dp,看数据范围,需要建虚树来优化

考虑每次把关键点建出虚树

dp求出每个点 u 的 \(belong_u\)(管辖 u 的关键点),\(dis_u\) (u 到 \(belong_u\) 的距离)

类似最短路的松弛,注意第一遍dp统计儿子对父亲的贡献,第二遍统计父亲对儿子的贡献

1
2
if (dis[u] > dis[v] + E[i].val || (dis[u] == dis[v] + E[i].val && bl[u] > bl[v]))
dis[u] = dis[v] + E[i].val, bl[u] = bl[v];

第三遍dp统计答案,设 \(f_u\) 表示原树中经过 u 增加 \(belong_u\) 贡献的点数

在虚树上有两种情况:

  • 以 u 为根的原树的子树中没有关键点,那么这棵子树都由 u 或 \(belong_u\) 管辖
  • 虚树上连接 u 和 v 的边(u 为 v 的父亲),代表原树中的一条链,又分两种情况:
    • \(belong_u=belong_v\),这一条链除了 v 点其它都是 u 的贡献(v 点及其子树为 v 的贡献)
    • \(belong_u\not=belong_v\),这条链被分成两部分,通过倍增找出分界点,划分 u,v 贡献

把 u 贡献加到 \(belong_u\) 的答案上即可

细节

这题一堆全局变量数组(命名冲突好麻烦),比大数据结构难调

关于链上找分界点(属于 v 范围的最高点):

先求出 \(belong_u\)\(belong_v\) 的距离 d,deep[v] - (d / 2 - dis[v]) 即中间点的 deep

如果 u,v 中间有奇数个点,必定有一个点 x 到两个关键点的距离相等,要让 d - 1,倍增后中间点为 x 的儿子(对 d 为偶数没有影响),看两个关键点编号大小选择是否往上走

1
2
3
4
5
int d = deep[v] - deep[u] + dis[u] + dis[v] - 1, tmp = d / 2 - dis[v], mid = v;
for (int j = 19; ~j; --j)
if (deep[fa[mid][j]] >= deep[v] - tmp)
mid = fa[mid][j];
if ((d & 1) && tmp >= 0 && bl[u] > bl[v]) mid = fa[mid][0];// tmp 可能为负

另外多测要清空

code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <cstdio>
#include <cctype>
#include <algorithm>
using namespace std;

template<typename T> inline void read(T &x) {}
template<typename TT> inline void print(TT x, char end = '\n') {}

const int N = 3e5 + 9;
const int INF = 0x3f3f3f3f;

struct Edge {
int nxt, to, val;
};

int n, cnt_e, Q, k, dfu, tot, top;
int head[N], Head[N], pu[N], po[N];
int deep[N], fa[N][20], size[N], a[N], t[N<<1], stk[N];
int dis[N], bl[N], f[N], ans[N];
Edge e[N<<1], E[N<<1];
bool is[N];

inline void add_edge(int u, int v, int w) {
e[++cnt_e] = (Edge){head[u], v, w}, head[u] = cnt_e;
}
inline void Add_edge(int u, int v, int w) {
E[++cnt_e] = (Edge){Head[u], v, w}, Head[u] = cnt_e;
}

void dfs(int u, int last) {
pu[u] = ++dfu;
deep[u] = deep[last] + 1;
fa[u][0] = last;
size[u] = 1;
for (int i = 1; 1 << i <= deep[u]; ++i)
fa[u][i] = fa[fa[u][i-1]][i-1];
for (int i = head[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != last)
dfs(v, u), size[u] += size[v];
po[u] = ++dfu;
}

int LCA(int x, int y) {
if (deep[x] < deep[y]) swap(x, y);
for (int i = 19; ~i; --i)
if (deep[fa[x][i]] >= deep[y])
x = fa[x][i];
for (int i = 19; ~i; --i)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return x == y ? x : fa[x][0];
}

inline bool cmp(int x, int y) {
return (x > 0 ? pu[x] : po[-x]) < (y > 0 ? pu[y] : po[-y]);
}

inline void build_tree(int kk) {
static bool vis[N];
tot = top = cnt_e = 0;
for (int i = 1; i <= kk; ++i)
t[++tot] = a[i], vis[t[tot]] = true;
sort(t+1, t+tot+1, cmp);
for (int i = 1, lca; i < kk; ++i) {
lca = LCA(t[i], t[i+1]);
if (!vis[lca]) t[++tot] = lca, vis[lca] = true;
}
if (!vis[1]) t[++tot] = 1;
kk = tot;
for (int i = 1; i <= kk; ++i)
t[++tot] = -t[i];
sort(t+1, t+tot+1, cmp);
for (int i = 1, u, v; i <= tot; ++i) {
if (t[i] > 0) stk[++top] = t[i];
else {
v = stk[top--], u = stk[top];
Add_edge(u, v, deep[v] - deep[u]), Add_edge(v, u, deep[v] - deep[u]);
vis[v] = false;
}
}
}

void dp1(int u, int last) {
if (is[u]) dis[u] = 0, bl[u] = u;
else dis[u] = INF;
for (int i = Head[u], v; i; i = E[i].nxt)
if ((v = E[i].to) != last) {
dp1(v, u);
if (dis[u] > dis[v] + E[i].val || (dis[u] == dis[v] + E[i].val && bl[u] > bl[v]))
dis[u] = dis[v] + E[i].val, bl[u] = bl[v];
}
}

// 我把第二三遍dp合在一块了
void dp2(int u, int last) {
f[u] = size[u];// 初值为子树size
for (int i = Head[u], v; i; i = E[i].nxt)
if ((v = E[i].to) != last) {
if (dis[v] > dis[u] + E[i].val || (dis[v] == dis[u] + E[i].val && bl[v] > bl[u]))
dis[v] = dis[u] + E[i].val, bl[v] = bl[u];
dp2(v, u);
if (bl[u] == bl[v])
f[u] -= size[v];
else {
int d = deep[v] - deep[u] + dis[u] + dis[v] - 1, tmp = d / 2 - dis[v], mid = v;
for (int j = 19; ~j; --j)
if (deep[fa[mid][j]] >= deep[v] - tmp)
mid = fa[mid][j];
if ((d & 1) && tmp >= 0 && bl[u] > bl[v]) mid = fa[mid][0];// tmp可能为负数
f[u] -= size[mid];
f[v] += size[mid] - size[v];
}
ans[bl[v]] += f[v];
}
if (u == 1) ans[bl[1]] += f[1];// 别落下根节点
}

inline void main() {
read(n);
for (int i = 1, u, v; i < n; ++i)
read(u), read(v), add_edge(u, v, 0), add_edge(v, u, 0);
dfs(1, 0);
read(Q);
while (Q--) {
read(k);
for (int i = 1; i <= k; ++i)
read(a[i]), is[a[i]] = true;
build_tree(k);
dp1(1, 0), dp2(1, 0);
for (int i = 1; i <= k; ++i)
print(ans[a[i]], ' '), is[a[i]] = ans[a[i]] = 0;
puts("");
// 清空,用memset会TLE
for (int i = 1; i <= tot; ++i)
if (t[i] > 0)
Head[t[i]] = dis[t[i]] = bl[t[i]] = 0;
}
return 0;
}