0%

总结 树链剖分

咕了很久,敲完才体会到什么是“出题人毫无意义地强行把代码增加5KB”


首先,树剖是把一棵树划分成多条轻重链,然后用线段树维护这些链

模板题

通过基础的树剖,可以做以下操作:

  • 将树从x到y最短路径上的权值都加上z
  • 求树从x到y最短路径上的权值和
  • 将以x为根节点的子树内权值都加上z
  • 求将以x为根节点的子树内的权值和

具体做法:

dfs1

遍历一遍树,求出每个点的:

\(fa[x]\):父亲

\(deep[x]\):深度

\(sz[x]\):子树大小

\(son[x]\):重儿子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void dfs1(int x, int last) {
fa[x] = last;
deep[x] = deep[last] + 1;
sz[x] = 1;
int y, maxson = -1;
for (int i = head[x]; i; i = e[i].nxt)
if (e[i].to != last) {
y = e[i].to;
dfs1(y, x);
sz[x] += sz[y];
if (sz[y] > maxson)
maxson = sz[y], son[x] = y;
}
}

dfs2

第二遍遍历要划分轻重链,先求出:

\(id[x]\):x的\(dfs\)

\(top[x]\):x所在链的起始点

注意先走重儿子,再遍历轻儿子,使得重链每个点的\(dfs\)序一定是连续的

根据\(dfs\)序的性质,x的子树每个点的\(dfs\)序也是连续的

因此可以用线段树维护得到的\(dfs\)序(记得把权值转移到\(dfs\)序上)

1
2
3
4
5
6
7
8
9
10
11
void dfs2(int x, int top_) {
id[x] = ++cnt;
top[x] = top_;
segtree.val[id[x]] = val[x];// 转移权值
if (!son[x])
return;
dfs2(son[x], top_);// 先重儿子
for (int i = head[x]; i; i = e[i].nxt)
if (e[i].to != fa[x] && e[i].to != son[x])
dfs2(e[i].to, e[i].to);// 轻儿子新开一条链
}

线段树

直接在\(dfs\)序上建立线段树,套模板即可

这里就不贴code了

update on 2020.1.21

听学长讲,对于每个链单独开一棵线段树可以减小常数


下面是愉快的各种操作。。。

路径修改/查询

对于每个x和y

我们可以不停的让深度大的跳到所在链的顶部,在线段树上直接操作一个链

直到x和y在同一个链上,然后还是线段树操作

(注意,让深度小的往上跳,可能会错过最短路径)

1
2
3
4
5
6
7
8
9
10
11
12
13
inline void change_road(int x, int y, int k) {
if (deep[x] < deep[y])
swap(x, y);// 选深度大的
while (top[x] != top[y]) {// 不在同一条链时
if (deep[top[x]] < deep[top[y]])
swap(x, y);//注意深度
segtree.change(1, 1, n, id[top[x]], id[x], k);// 注意顺序,链顶的id一定大于x的id
x = fa[top[x]];
}
if (deep[x] > deep[y])
swap(x, y);// 注意顺序
segtree.change(1, 1, n, id[x], id[y], k);
}

路径查询同理,就不贴了

话说这里太容易出bug了 QAQ

子树修改/查询

因为以x为根的子树在\(dfs\)序上一定是连续的一段

线段树直接操作\(id[x]\)\(id[x]+sz[x]-1\)的区间

segtree.change(1, 1, n, id[x], id[x]+sz[x]-1, z);

查询同理


LuoguP3384 AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#include <cstdio>
#include <cctype>

using namespace std;

namespace BANANA {

template<typename T> inline void read(T &x) {
x = 0; T k = 1; char in = getchar();
while (!isdigit(in)) { if (in == '-') k = -1; in = getchar(); }
while (isdigit(in)) x = x * 10 + in - '0', in = getchar();
x *= k;
}

const int N = 1e5 + 5;

struct Edge {
int nxt, to;
};

struct Segtree {
struct Leave {
int ls, rs, sum, lazy;
};
Leave tr[N<<2];
int mod, cnt;
int val[N];

inline void push_up(int suc) {
tr[suc].sum = (tr[tr[suc].ls].sum + tr[tr[suc].rs].sum) % mod;
}

inline void push_down(int suc, int L, int R) {
int ls = tr[suc].ls, rs = tr[suc].rs, mid = (L + R) >> 1;
(tr[ls].lazy += tr[suc].lazy) %= mod;
(tr[ls].sum += (mid - L + 1) * tr[suc].lazy) %= mod;
(tr[rs].lazy += tr[suc].lazy) %= mod;
(tr[rs].sum += (R - mid) * tr[suc].lazy) %= mod;
tr[suc].lazy = 0;
}

void build(int suc, int L, int R) {
tr[suc].lazy = 0;
if (L == R) {
tr[suc].sum = val[L];
return;
}
int mid = (L + R) >> 1;
tr[suc].ls = ++cnt, tr[suc].rs = ++cnt;
build(tr[suc].ls, L, mid), build(tr[suc].rs, mid+1, R);
push_up(suc);
}

void change(int suc, int L, int R, int cl, int cr, int k) {
if (cl <= L && R <= cr) {
(tr[suc].sum += (R - L + 1) * k) %= mod;
(tr[suc].lazy += k) %= mod;
return;
}
push_down(suc, L, R);
int mid = (L + R) >> 1;
if (cl <= mid)
change(tr[suc].ls, L, mid, cl, cr, k);
if (cr > mid)
change(tr[suc].rs, mid+1, R, cl, cr, k);
push_up(suc);
}

int query(int suc, int L, int R, int ql, int qr) {
if (ql <= L && R <= qr)
return tr[suc].sum;
push_down(suc, L, R);
int mid = (L + R) >> 1;
if (qr <= mid)
return query(tr[suc].ls, L, mid, ql, qr);
if (ql > mid)
return query(tr[suc].rs, mid+1, R, ql, qr);
return (query(tr[suc].ls, L, mid, ql, qr) + query(tr[suc].rs, mid+1, R, ql, qr)) % mod;
}
};

int n, Q, root, mod, cnt;
int head[N], val[N], deep[N], sz[N], fa[N], son[N], top[N], id[N];
Edge e[N<<1];
Segtree segtree;

inline void add(int u, int v) {
e[++cnt] = (Edge){head[u], v}, head[u] = cnt;
}

inline void swap(int &a, int &b) {
int t = a;
a = b, b = t;
}

void dfs1(int x, int last) {
fa[x] = last;
deep[x] = deep[last] + 1;
sz[x] = 1;
int y, maxson = -1;
for (int i = head[x]; i; i = e[i].nxt)
if (e[i].to != last) {
y = e[i].to;
dfs1(y, x);
sz[x] += sz[y];
if (sz[y] > maxson)
maxson = sz[y], son[x] = y;
}
}

void dfs2(int x, int top_) {
id[x] = ++cnt;
top[x] = top_;
segtree.val[id[x]] = val[x];
if (!son[x])
return;
dfs2(son[x], top_);
for (int i = head[x]; i; i = e[i].nxt)
if (e[i].to != fa[x] && e[i].to != son[x])
dfs2(e[i].to, e[i].to);
}

inline void change_road(int x, int y, int k) {
if (deep[x] < deep[y])
swap(x, y);
while (top[x] != top[y]) {
if (deep[top[x]] < deep[top[y]])
swap(x, y);
segtree.change(1, 1, n, id[top[x]], id[x], k);
x = fa[top[x]];
}
if (deep[x] > deep[y])
swap(x, y);
segtree.change(1, 1, n, id[x], id[y], k);
}

inline int query_road(int x, int y) {
if (deep[x] < deep[y])
swap(x, y);
int res = 0;
while (top[x] != top[y]) {
if (deep[top[x]] < deep[top[y]])
swap(x, y);
(res += segtree.query(1, 1, n, id[top[x]], id[x])) %= mod;
x = fa[top[x]];
}
if (deep[x] > deep[y])
swap(x, y);
(res += segtree.query(1, 1, n, id[x], id[y])) %= mod;
return res;
}

inline void main() {
read(n), read(Q), read(root), read(mod);
for (int i = 1; i <= n; ++i)
read(val[i]);
for (int i = 1, u, v; i < n; ++i)
read(u), read(v), add(u, v), add(v, u);
dfs1(root, 0);
cnt = 0;
dfs2(root, root);
segtree.mod = mod, segtree.cnt = 1;
segtree.build(1, 1, n);

int opt, x, y, z;
while (Q--) {
read(opt);
switch (opt) {
case 1:
read(x), read(y), read(z);
change_road(x, y, z);
break;
case 2:
read(x), read(y);
printf("%d\n", query_road(x, y));
break;
case 3:
read(x), read(y);
segtree.change(1, 1, n, id[x], id[x]+sz[x]-1, y);
break;
case 4:
read(x);
printf("%d\n", segtree.query(1, 1, n, id[x], id[x]+sz[x]-1));
break;
}
}
}
}

int main() {
BANANA::main();
return 0;
}

第一次写的时候太艰辛了,调了半天发现是\(swap()\)写错了\(qwq\)