HDR002B 病毒

发布于 2021-09-20  1.22k 次阅读


题目链接

题意

给出一个 $n$ 个点的树,然后有 $q$ 次询问,每次询问给出一个大小为 $k_i$ 的点集,问树上是否存在一个点使得该点到点集中任意一点的距离均相同。

$n \le 10^6,~1 \le 5 \times 10^5,~\sum k_i \le 10^6$

分析

令我们要找的这个点到点集中任意一个点的距离为 $dist$,根据性质不难发现,若存在解,则点集中的最远点对(即从点集中取出不相同两个点能得到的最远距离)的长度必然为 $dist \times 2$。

证明如下。令最远点对为 $(x,~y)$,则我们先将 $x$ 和 $y$ 加入图中,然后依次加入点集中的其他点。在只加入 $x$ 和 $y$ 时当前答案点为它们之间的链的中点 $z$。在加入点集中的其他点时,如果要让答案点离开 $z$ 到达其他点,只有可能新加入的点到 $z$ 的距离大于 $z$ 与 $x,~y$ 之间的距离。但是由于 $(x,~y)$ 为最远点对,所以无法找到这样的点能使得答案点离开 $z$。因此如果存在答案则答案一定为 $z$。

因此对于给出的点集,我们只需要用求直径的方式先求出点集中的最远点对,并将最远点对之间的链上的中点拿出作为答案,并检验点集中其他点到答案点的距离是否都相同。如果都满足条件,则此点即为我们要求的答案点;如果存在点不满足条件或是最远点对之间的距离为奇数,则无解。

时间复杂度 $O(\sum k \times \log n)$,可能存在轻度卡常。

代码

View on GitHub

/**
 * @author Macesuted (i@macesuted.moe)
 * @copyright Copyright (c) 2021
 * @brief
 *      My solution: https://macesuted.cn/article/hdr002b/
 */

#include <bits/stdc++.h>
using namespace std;

namespace io {
#define SIZE (1 << 20)
char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55];
int f, qr;
inline void flush(void) { return fwrite(obuf, 1, oS - obuf, stdout), oS = obuf, void(); }
inline char getch(void) { return (iS == iT ? (iT = (iS = ibuf) + fread(ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS++)) : *iS++); }
inline void putch(char x) {
    *oS++ = x;
    if (oS == oT) flush();
    return;
}
string getstr(void) {
    queue<char> que;
    char c = getch();
    while (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF) c = getch();
    while (!(c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF)) que.push(c), c = getch();
    string s;
    s.resize(que.size());
    for (register int i = 0; i < (int)s.size(); i++) s[i] = que.front(), que.pop();
    return s;
}
void putstr(string str, int begin = 0, int end = -1) {
    if (end == -1) end = str.size();
    for (register int i = begin; i < end; i++) putch(str[i]);
    return;
}
template <typename T>
inline T read() {
    register T x = 0;
    for (f = 1, c = getch(); c < '0' || c > '9'; c = getch())
        if (c == '-') f = -1;
    for (x = 0; c <= '9' && c >= '0'; c = getch()) x = x * 10 + (c & 15);
    return x * f;
}
template <typename T>
inline void write(const T& t) {
    register T x = t;
    if (!x) putch('0');
    if (x < 0) putch('-'), x = -x;
    while (x) qu[++qr] = x % 10 + '0', x /= 10;
    while (qr) putch(qu[qr--]);
    return;
}
struct Flusher_ {
    ~Flusher_() { flush(); }
} io_flusher_;
}  // namespace io
using io::getch;
using io::getstr;
using io::putch;
using io::putstr;
using io::read;
using io::write;

#define maxn 1000005
#define maxlgn 21

vector<vector<int>> graph;

int dep[maxn], siz[maxn], son[maxn], fa[maxn][maxlgn], top[maxn];
int a[maxn];

void dfs(int p, int pre) {
    dep[p] = dep[pre] + 1;
    siz[p] = 1;
    fa[p][0] = pre;
    for (register int i = 1; i < maxlgn; i++) fa[p][i] = fa[fa[p][i - 1]][i - 1];
    for (vector<int>::iterator i = graph[p].begin(); i != graph[p].end(); i++)
        if (*i != pre) {
            dfs(*i, p);
            siz[p] += siz[*i];
            if (siz[*i] > siz[son[p]]) son[p] = *i;
        }
    return;
}
void dfs1(int p, int ftop) {
    top[p] = ftop;
    if (!son[p]) return;
    dfs1(son[p], ftop);
    for (vector<int>::iterator i = graph[p].begin(); i != graph[p].end(); i++)
        if (*i != son[p] && *i != fa[p][0]) dfs1(*i, *i);
    return;
}
int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]][0];
    }
    return dep[x] < dep[y] ? x : y;
}
inline int dist(int x, int y) { return dep[x] + dep[y] - (dep[LCA(x, y)] << 1); }
int jump(int p, int step) {
    int t = 0;
    while (step) {
        if (step & 1) p = fa[p][t];
        step >>= 1, t++;
    }
    return p;
}

int main() {
    int n = read<int>();
    graph.resize(n + 1);
    for (register int i = 1, from, to; i < n; i++) {
        from = read<int>(), to = read<int>();
        graph[from].push_back(to), graph[to].push_back(from);
    }
    dfs(1, 0), dfs1(1, 1);
    int q = read<int>();
    while (q--) {
        int k = read<int>();
        for (register int i = 1; i <= k; i++) a[i] = read<int>();
        int node1 = a[1];
        for (register int i = 2; i <= k; i++)
            if (dep[a[i]] > dep[node1]) node1 = a[i];
        int dis = 0, node2 = node1;
        for (register int i = 1; i <= k; i++) {
            register int cdis = dist(node1, a[i]);
            if (cdis > dis) dis = cdis, node2 = a[i];
        }
        if (dis & 1) {
            write(-1), putch('\n');
            continue;
        }
        int mid = jump(dep[node1] > dep[node2] ? node1 : node2, dis >>= 1);
        bool check = true;
        for (register int i = 1; i <= k; i++)
            if (dist(a[i], mid) != dis) {
                write(-1), putch('\n');
                check = false;
                break;
            }
        if (check) write(mid), putch(' '), write(dis), putch('\n');
    }
    return 0;
}

我缓慢吐出一串啊吧啊吧并不再想说话