CF1616G Just Add an Edge

发布于 2022-02-22  2.9k 次阅读


题目链接

题目大意

给出一个 $n$ 个点的 DAG,满足 $1,~2,~3,~\dots,~n$ 为它的一个拓扑序。问有多少对满足条件的二元组 $(x,~y)$ 满足 $x < y$ 且在原图上加入一条 $y \to x$ 的有向边后原图存在哈密顿路径。

$n,~m \le 1.5 \times 10^5$

分析

若原图本身存在哈密顿路径,则答案即为 $n \choose 2$。

否则我们可以从原图上找出两条链,满足每个点均在其中一条链上,将这两条链首尾相连即可得到哈密顿路径。令最终我们加入的边为 $y \to x$,则哈密顿路径应形如:

  1. $1 \to 2 \to 3 \to \cdots \to (x - 1)$
  2. $(x - 1) \to \cdots \to y$
  3. $y \to x$
  4. $x \to \cdots \to (y + 1)$
  5. $(y + 1) \to (y + 2) \to \cdots \to n$

其中 2、4 两部分包括 $(x-1) \sim (y + 1)$ 之间的所有结点,且 $(x - 1)$ 与 $y$ 在同一条链上,$x$ 与 $(y + 1)$ 在另一条链上。若采用状态 $(x,~y)$ 表示 $x$ 与 $1$ 在同一条链上,$y$ 与 $n$ 在另一条链上,当 $(x-1,~x)$ 和 $(y,~y+1)$ 状态在同一方案中同时存在时连接 $y \to x$ 为一种合法方案。

将所有结点 $a$ 与结点 $(a + 1)$ 所在链不同的情况拿出,令 $S_x = (x,~x + 1),~T_x = (x + 1,~x)$。将这 $O(n)$ 个状态看为结点,在同时出现于同一方案中的相邻状态对应的结点间连边。不难发现相邻两状态的关系一定形如这样:

即在 $(x + 1)$ 可向右以步长 $1$ 走到 $y$ 且存在边 $x \to (y + 1)$ 时 $S_x = T_y$。通过此方法可以在这 $O(n)$ 个状态间连上 $O(m)$ 条边。最终答案即为:满足下面条件的数对 $(x,~y)$ 数量:

  1. 可从 $1$ 号点以步长 $1$ 走到 $x$。
  2. 可从 $(y + 1)$ 以步长 $1$ 走到 $n$。
  3. $S_x = S_y$

官方题解给出了一种巧妙地方法计算这样的数对的数量:由于原图中不存在哈密顿路径,所以必然存在至少一个点 $p$ 满足不存在边 $p \to (p + 1)$。不难发现对于每一对相连的 $S_x,~S_y$,它们两点之间的路径上必然存在 $S_p$ 或 $T_p$。因此建完图后从 $S_p$ 和 $T_p$ 开始分别向左和向右遍历连通块,将左侧连通块内满足条件 1 的结点数量与右侧连通块内满足条件 2 的结点相乘即可。由于可能存在 $S_x$ 与 $S_y$ 同时在 $S_p$ 和 $T_p$ 两侧的连通块中的情况,因此最终需容斥减去算重的部分。

时间复杂度 $O(n \log n)$。

代码

View on GitHub

/**
 * @file 1616G.cpp
 * @author Macesuted (i@macesuted.moe)
 * @date 2022-02-22
 *
 * @copyright Copyright (c) 2022
 * @brief
 *      My Tutorial: https://macesuted.moe/article/cf1616g
 *
 */

#include <bits/stdc++.h>
using namespace std;

#define MP make_pair
#define MT make_tuple

namespace io {
#define SIZE (1 << 20)
char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55];
int f, qr;
inline void flush(void) { return fwrite(obuf, 1, oS - obuf, stdout), oS = obuf, void(); }
inline char getch(void) {
    return (iS == iT ? (iT = (iS = ibuf) + fread(ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS++)) : *iS++);
}
void putch(char x) {
    *oS++ = x;
    if (oS == oT) flush();
    return;
}
string getstr(void) {
    string s = "";
    char c = getch();
    while (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF) c = getch();
    while (!(c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF)) s.push_back(c), c = getch();
    return s;
}
void putstr(string str, int begin = 0, int end = -1) {
    if (end == -1) end = str.size();
    for (int i = begin; i < end; i++) putch(str[i]);
    return;
}
template <typename T>
T read() {
    T x = 0;
    for (f = 1, c = getch(); c < '0' || c > '9'; c = getch())
        if (c == '-') f = -1;
    for (x = 0; c <= '9' && c >= '0'; c = getch()) x = x * 10 + (c & 15);
    return x * f;
}
template <typename T>
void write(const T& t) {
    T x = t;
    if (!x) putch('0');
    if (x < 0) putch('-'), x = -x;
    while (x) qu[++qr] = x % 10 + '0', x /= 10;
    while (qr) putch(qu[qr--]);
    return;
}
struct Flusher_ {
    ~Flusher_() { flush(); }
} io_flusher_;
}  // namespace io
using io::getch;
using io::getstr;
using io::putch;
using io::putstr;
using io::read;
using io::write;

bool mem1;

#define maxn 150005

bool cons[maxn], vis[2][maxn * 2];

vector<vector<int>> graph, g, gr;

void dfs(int p, vector<vector<int>>& graph, int id) {
    vis[id][p] = true;
    for (auto i : graph[p])
        if (!vis[id][i]) dfs(i, graph, id);
    return;
}

void solve(void) {
    int n = read<int>() + 2, m = read<int>();
    graph.clear(), g.clear(), gr.clear(), graph.resize(n + 1), g.resize(2 * n + 1), gr.resize(2 * n + 1);
    for (int i = 1; i <= n; i++) cons[i] = false, vis[0][i] = vis[1][i] = vis[0][n + i] = vis[1][n + i] = false;
    for (int i = 3; i < n; i++) graph[1].push_back(i), graph[i - 1].push_back(n);
    cons[1] = cons[n - 1] = true;
    for (int i = 1; i <= m; i++) {
        int x = read<int>() + 1, y = read<int>() + 1;
        if (x + 1 == y)
            cons[x] = true;
        else
            graph[x].push_back(y);
    }
    int p = 0, l = 0, r = 0;
    for (int i = 1; i < n; i++)
        if (!cons[i]) {
            if (!p) p = l = i;
            r = i;
        }
    if (p == 0) return write(1LL * (n - 2) * (n - 3) / 2), putch('\n');
    for (int i = n, r = n; i > 1; i--) {
        if (!cons[i]) r = i;
        for (auto j : graph[i - 1])
            if (j <= r + 1)
                g[i - 1].push_back(n + j), g[n + i].push_back(j - 1), gr[n + j].push_back(i - 1), gr[j - 1].push_back(n + i);
    }
    dfs(p, g, 0), dfs(p, gr, 0), dfs(n + p + 1, g, 1), dfs(n + p + 1, gr, 1);
    long long ans = 0, cnt1 = 0, cnt2 = 0;
    for (int i = 1; i <= l; i++) cnt1 += vis[0][i];
    for (int i = r; i < n; i++) cnt2 += vis[0][i];
    ans += cnt1 * cnt2, cnt1 = cnt2 = 0;
    for (int i = 1; i <= l; i++) cnt1 += vis[1][i];
    for (int i = r; i < n; i++) cnt2 += vis[1][i];
    ans += cnt1 * cnt2, cnt1 = cnt2 = 0;
    for (int i = 1; i <= l; i++) cnt1 += (vis[0][i] & vis[1][i]);
    for (int i = r; i < n; i++) cnt2 += (vis[0][i] & vis[1][i]);
    ans -= cnt1 * cnt2;
    return write(ans - (l == r)), putch('\n');
}

bool mem2;

int main() {
#ifdef MACESUTED
    cerr << "Memory: " << abs(&mem1 - &mem2) / 1024. / 1024. << "MB" << endl;
#endif

    int _ = read<int>();
    while (_--) solve();

#ifdef MACESUTED
    cerr << "Time: " << clock() * 1000. / CLOCKS_PER_SEC << "ms" << endl;
#endif
    return 0;
}

我缓慢吐出一串啊吧啊吧并不再想说话