0%

题解 P4218 [CTSC2010]珠宝商

SAM + 点分治 + 根号分治。

很综合的一道题。

题目

给出 \(n\) 个点的树,每个点有一个字符,一条路径 \((u,v)\) 代表一个字符串,\(str(u,v)\neq str(v,u)\)。再给出一个串 \(S\)

\(count(T)\) 表示 \(T\)\(S\) 中出现次数,求 \(\sum_{u,v} count(str(u,v))\)

\(n,m\le 5\times 10^4\)

思路

暴力

我们先想一个 \(\mathcal O(n^2)\) 的暴力。

\(S\) 串建立 SAM,考虑把每个点 \(u\) 作为根,一边 dfs 一边在 SAM 上匹配,可以统计所有一端为 \(u\) 的路径的答案。

点分治

树上路径计数问题,可以想到点分治。

点分治就需要考虑链的合并,可以对于 \(S\) 和反串 \(S^r\) 分别建立 SAM,记为 \(A_0,A_1\)

每次分治,设重心 \(u\),将每一条以 \(u\) 开头的链 \(str(u,v)\)\(A_0,A_1\) 分别标记。

然后在 \(A_0,A_1\) 的 parent 树上下放所有标记,那么自动机上每一个点就记录了该点代表的子串包含的所有链作为后缀的个数。

统计答案就是在 \(S\) 串上枚举拼接点 \(i\),以 \(i\) 结尾的后缀和 \(i\) 开头的前缀都可以计入答案。

点分治一定要考虑容斥,实现的过程中,发现既要往一个子串后面插入(走 \(trans\)),又要往前面插入(走后缀树 \(ch\)),所以需要用 SAM 建出后缀树。

这样每次复杂度 \(\mathcal O(size+m)\)\(size\) 是每次分治连通块大小。

总复杂度 \(\mathcal O(n\log n+nm)\),不行啊。

根号分治

发现瓶颈在于,对于很小的连通块,我们都需要 \(\mathcal O(m)\) 遍历整个 SAM。

想起暴力做法复杂度与 \(m\) 无关,那么根号分治一下,小于 \(\sqrt{n}\) 的部分就暴力统计。

最后总复杂度好像是 \(\mathcal O((n+m)\sqrt{n})\)

代码

细节很多啊,调了三天。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
const int N = 2e5 + 9;

int n, m, root, all, block;
LL ans;
int val[N], size[N], maxs[N], fat[N];
bool vis[N];
char str[N];
vector<int> vec;

struct Graph {
int cnt;
int head[N], nxt[N], to[N];
inline void add(int u, int v) {
++cnt, nxt[cnt] = head[u], to[cnt] = v, head[u] = cnt;
}
} G;

struct SuffixAutomaton {
int tot;
// trans:自动机, ch:后缀树
int trans[N][26], ch[N][26], link[N], len[N], size[N], id[N], pos[N], tag[N], str[N];
Graph T;

SuffixAutomaton(): tot(1) {}
inline int insert(int c, int last, int ps) {
int p = last, np = ++tot;
len[np] = len[p] + 1, ++size[np], pos[np] = ps;
while (p && !trans[p][c]) trans[p][c] = np, p = link[p];
if (!p) link[np] = 1;
else {
int q = trans[p][c];
if (len[q] == len[p] + 1) link[np] = q;
else {
int nq = ++tot;
len[nq] = len[p] + 1, link[nq] = link[q], pos[nq] = pos[q];
memcpy(trans[nq], trans[q], sizeof trans[q]);
link[q] = link[np] = nq;
while (p && trans[p][c] == q) trans[p][c] = nq, p = link[p];
}
}
return np;
}
void dfs(int u) {
for (int i = T.head[u]; i; i = T.nxt[i]) {
int v = T.to[i];
dfs(v);
size[u] += size[v];
ch[u][str[pos[v] - len[u]]] = v;
pos[u] = max(pos[u], pos[v]);
}
}
inline void init() {
for (int i = 1, last = 1; i <= m; ++i)
id[i] = last = insert(str[i], last, i);
for (int i = 2; i <= tot; ++i) T.add(link[i], i);
dfs(1);
}
inline int nxt(int p, int l, int c) {
if (l <= len[p]) return str[pos[p] - l + 1] == c ? p : 0; // bug
return ch[p][c];
}
void mark(int u, int fa, int p, int l) {
if (!p) return;
++tag[p];
for (int i = G.head[u]; i; i = G.nxt[i]) {
int v = G.to[i];
if (vis[v] || v == fa) continue;
mark(v, u, nxt(p, l + 1, val[v]), l + 1);
}
}
void push_down(int u) {
for (int i = 0; i < 26; ++i) {
int v = ch[u][i];
if (!v) continue;
tag[v] += tag[u];
push_down(v);
}
}
inline void clear() {
for (int i = 1; i <= tot; ++i) tag[i] = 0;
}
} A0, A1;

void find_root(int u, int fa) {
size[u] = 1, maxs[u] = 0;
for (int i = G.head[u]; i; i = G.nxt[i]) {
int v = G.to[i];
if (vis[v] || v == fa) continue;
find_root(v, u);
size[u] += size[v];
maxs[u] = max(maxs[u], size[v]);
}
maxs[u] = max(maxs[u], all - size[u]);
if (!root || maxs[u] < maxs[root]) root = u;
}

void force_dfs1(int u, int fa) {
vec.push_back(u);
fat[u] = fa;
for (int i = G.head[u]; i; i = G.nxt[i]) {
int v = G.to[i];
if (vis[v] || v == fa) continue;
force_dfs1(v, u);
}
}

void force_dfs2(int u, int fa, int p, int op) {
if (!p) return;
ans += (LL)A0.size[p] * op;
for (int i = G.head[u]; i; i = G.nxt[i]) {
int v = G.to[i];
if (vis[v] || v == fa) continue;
force_dfs2(v, u, A0.trans[p][val[v]], op);
}
}

inline void force_add(int u) {
vec.clear();
force_dfs1(u, 0);
for (int i = 0; i < (int)vec.size(); ++i) {
int x = vec[i];
force_dfs2(x, 0, A0.trans[1][val[x]], 1);
}
}

inline void force_del(int u, int fa) {
vec.clear();
force_dfs1(u, fa);
for (int i = 0; i < (int)vec.size(); ++i) {
int x = vec[i], p = 1;
while (x != fa) p = A0.trans[p][val[x]], x = fat[x];
p = A0.trans[p][val[fa]];
force_dfs2(u, 0, A0.trans[p][val[u]], -1);
}
}

inline void calc(int u, int fa, int op) {
if (op == 1) {
A0.mark(u, 0, A0.trans[1][val[u]], 1);
A1.mark(u, 0, A1.trans[1][val[u]], 1);
}
else {
A0.mark(u, fa, A0.nxt(A0.trans[1][val[fa]], 2, val[u]), 2);
A1.mark(u, fa, A1.nxt(A1.trans[1][val[fa]], 2, val[u]), 2);
}
A0.push_down(1), A1.push_down(1);
for (int i = 1; i <= m; ++i)
ans += (LL)op * A0.tag[A0.id[i]] * A1.tag[A1.id[m - i + 1]]; // bug
A0.clear(), A1.clear();
}

void solve(int u) {
if (all <= block) return force_add(u);
vis[u] = true;
calc(u, 0, 1);
int now = all;
for (int i = G.head[u]; i; i = G.nxt[i]) {
int v = G.to[i];
if (vis[v]) continue;
all = size[v] > size[u] ? (now - size[u]) : size[v];
if (all <= block) force_del(v, u);
else calc(v, u, -1);
root = 0, find_root(v, u);
solve(root);
}
}

inline void main() {
cin >> n >> m;
block = sqrt(n);
for (int i = 1; i < n; ++i) {
int u, v;
cin >> u >> v;
G.add(u, v), G.add(v, u);
}
cin >> (str + 1);
for (int i = 1; i <= n; ++i) val[i] = str[i] - 'a';

cin >> (str + 1);
for (int i = 1; i <= m; ++i)
A0.str[i] = A1.str[m - i + 1] = str[i] - 'a';
A0.init(), A1.init();

all = n, root = 0;
find_root(1, 0), solve(root);
cout << ans << '\n';
}