点分治

题目描述

给定一棵有$n$个点的树

询问树上距离为$k$的点对是否存在。

输入输出格式

输入格式:

$n$,$m$ 接下来$n-1$条边$a$,$b$,$c$描述$a$到$b$有一条长度为c的路径

接下来$m$行每行询问一个$K$

输出格式:

对于每个K每行输出一个答案,存在输出“AYE”,否则输出”NAY”(不包含引号)

输入输出样例

输入样例#1:

1
2
3
2 1
1 2 2
2

输出样例#1:

1
AYE

说明

对于$30\%$的数据$n\leq100$

对于$60\%$的数据$n\leq1000$,$m\leq50$

对于$100\%$的数据$n\leq10000$,$m\leq100$,$c\leq1000$,$K\leq10000000$

题解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include<bits/stdc++.h>
#define ps puts("")
#define fi first
#define nd second
#define mset(x) memset((x), 0, sizeof (x))
#define mk make_pair
#define sqr(x) ((x)*(x))
using namespace std;
typedef long long ll;
ll read() {ll x = 0;char f = 1, ch = getchar();while(ch < '0' || ch > '9') {if(ch == '-')f = -1;
ch = getchar();}while(ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}return x * f;}
void write(ll x) {if(x < 0) x = -x, putchar('-');if(x > 9) write(x / 10);putchar(x % 10 + '0');}
inline void writeln(ll x) {write(x);puts("");}
const int N = 32000;
int n, m;
int ver[N], nxt[N], w[N], en, head[N];
void add(int x, int y, int z) {
ver[++en] = y, nxt[en] = head[x], head[x] =en,w[en] = z;
}
int q[210], rt, mx[N], siz[N], sum;
bool vis[N];
void getrt(int x, int F) {
siz[x] = 1;
for(int i = head[x]; i;i = nxt[i]) {
int y = ver[i];
if(y == F || vis[y]) continue;
getrt(y,x);
siz[x] += siz[y];
mx[x] = max(mx[x], siz[y]);
}
mx[x] = max(mx[x], sum - siz[x]);
if(mx[x] < mx[rt]) rt = x;
}
bool ju[10001000];
int rem[N], dis[N], ans[N];
void getdis(int x, int F) {
rem[++rem[0]] = dis[x];
for(int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if(y == F || vis[y]) continue;
dis[y] = dis[x] + w[i];
getdis(y,x);
}
}
int tmp[N];
void cal(int x) {
int p = 0;
for(int i = head[x]; i;i = nxt[i]) {
int y = ver[i];
if(vis[y]) continue;
rem[0] = 0; dis[y] = w[i];
getdis(y, x);
for(int j = rem[0]; j; --j)
for(int k = 1; k <= m; ++k)
if(q[k] >= rem[j])
ans[k] |= ju[q[k] - rem[j]];
for(int j = rem[0]; j; --j)
tmp[++p] = rem[j], ju[rem[j]] = 1;
}
for(int i = 1; i <= p; ++i)
ju[tmp[i]] = 0;
}
void solve(int x) {
vis[x] = ju[0] = 1;cal(x);
for(int i = head[x]; i;i = nxt[i]) {
int y = ver[i];
if(vis[y]) continue;
sum = siz[y]; rt = 0;
getrt(y, 0); solve(y);
}
}

int main() {
n = read(), m = read();
for(int i = 1; i < n; ++i) {
int x = read(), y = read(), z = read();
add(x, y, z);
add(y, x, z);
}
sum = mx[rt] = n;
for(int i = 1; i <= m; ++i)
q[i] = read();
getrt(1, 0);
solve(rt);
for(int i = 1; i <= m; ++i)
puts(ans[i] ? "AYE" : "NAY");
return 0;
}