H1034 [Ynoi2017] 由乃的 OJ

发布于 2021-07-19  3.74k 次阅读


题面

Statement

给定一棵 $n$ 个点的树,每个节点上均有一个二进制运算符(&|^)和一个在 $[0,~2^k-1]$ 之间的数值。

接下来给定 $m$ 个操作,每个操作分为两类:

  • 修改某个结点上的运算符与数值。
  • 给定 $x,~y,~z$,你可以先任意指定一个值 $val \in [0,~z]$,然后在树上沿 $x$ 与 $y$ 之间的简单路径从 $x$ 向 $y$ 移动,每次到达一个结点后将 $val$ 变为 $val~op[i]~a[i]$,$op[i]$ 为该节点上运算符,$a[i]$ 为该节点上数值。最大化到达 $y$ 结点时的 $val$ 值并输出。

$n,~m \le 10^5,~0 \le k \le 64$

Solution

由于询问时询问的是树上两点间路径的信息,不难想到通过树链剖分将每条路径转化为 $\log n$ 个区间。

由于每个结点上的运算符均为二进制位运算符,运算过程中不同二进制位之间互不影响。我们可以考虑对于每一个二进制位分开来考虑其运算结果,即对于每一个二进制位都求出其为 $0/1$ 时经过路径后的结果,最后通过类似数位 DP 的计算方法即可求出 $[0,~z]$ 区间内的初值可产生的最大运算结果。

而由于询问时两点间的路径是有向的,从 $x \to lca$ 的路径是向上的,从 $lca \to y$ 的路径是向下的。将树上移动的顺序对应到区间上移动的顺序,不难发现 $x \to lca$ 的路径对应的区间都是从右向左经过的,$lca \to y$ 的路径对应的区间都是从左向右经过的。因此我们需要对于每一个区间都求出每一个二进制位初始为 $0/1$ 时从左向右或是从右向左经过该区间之后的值。

考虑如何维护,建立一棵线段树,每个线段树结点都记录 $l0[i],~l1[i],~r0[i],~r1[i]$ 分别表示:

  • 二进制第 $i$ 位为 $0$ 时从左向右经过该区间后该位的值。
  • 二进制第 $i$ 位为 $1$ 时从左向右经过该区间后该位的值。
  • 二进制第 $i$ 位为 $0$ 时从右向左经过该区间后该位的值。
  • 二进制第 $i$ 位为 $1$ 时从右向左经过该区间后该位的值。

每次合并两个结点 $a,~b$ 的信息时令:

  • ans.l0[i] = a.l0[i] && b.l1[i] || !a.l0[i] && b.l0[i]
  • ans.l1[i] = a.l1[i] && b.l1[i] || !a.l1[i] && b.l0[i]
  • ans.r0[i] = b.r0[i] && a.r1[i] || !b.r0[i] && a.r0[i]
  • ans.r1[i] = b.r1[i] && a.r1[i] || !b.r1[i] && a.r0[i]

即枚举某一位在经过第一个区间后的值,将其以初值代入第二个区间得到结果。

此时对于每一个询问我们将路径拆为 $\log n$ 个区间,每个区间对应 $\log n$ 个线段树结点,每次合并结点需要花费 $O(k)$ 的时间,因此此时的总时间复杂度为 $O(m \times k \times \log ^2n)$,无法通过本题。

Optimization

我们仔细分析上面的时间复杂度,发现两个 $\log n$ 在该算法中都难以去除,因此我们考虑优化掉 $O(k)$ 的时间复杂度。容易发现 $O(k)$ 的时间复杂度来自线段数结点合并。

仔细观察我们发现对于 $k$ 位分别进行逻辑运算是非常浪费的,我们考虑将四个大小为 $k$ 的数组压为四个大小为 $2^k$ 的数字,将逻辑运算转变为二进制位运算即可。对应的四个转移变为:

  • ans.l0 = (a.l0 & b.l1) | (~a.l0 & b.l0)
  • ans.l1 = (a.l1 & b.l1) | (~a.l1 & b.l0)
  • ans.r0 = (b.r0 & a.r1) | (~b.r0 & a.r0)
  • ans.r1 = (b.r1 & a.r1) | (~b.r1 & a.r0)

因此合并结点信息的复杂度优化到 $O(1)$,总复杂度达到 $O(m \times \log ^2n)$,足以通过此题。

Code

View on GitHub

/**
 * @author Macesuted (i@macesuted.moe)
 * @copyright Copyright (c) 2021
 * @brief 
 *      My Solution: https://macesuted.moe/article/h1034
 */

#include <bits/stdc++.h>
using namespace std;

namespace io {
#define SIZE (1 << 20)
char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55];
int f, qr;
inline void flush(void) { return fwrite(obuf, 1, oS - obuf, stdout), oS = obuf, void(); }
inline char getch(void) { return (iS == iT ? (iT = (iS = ibuf) + fread(ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS++)) : *iS++); }
inline void putch(char x) {
    *oS++ = x;
    if (oS == oT) flush();
    return;
}
string getstr(void) {
    string s = "";
    char c = getch();
    while (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF) c = getch();
    while (!(c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF)) s.push_back(c), c = getch();
    return s;
}
void putstr(string str, int begin = 0, int end = -1) {
    if (end == -1) end = str.size();
    for (register int i = begin; i < end; i++) putch(str[i]);
    return;
}
template <typename T>
inline T read() {
    register T x = 0;
    for (f = 1, c = getch(); c < '0' || c > '9'; c = getch())
        if (c == '-') f = -1;
    for (x = 0; c <= '9' && c >= '0'; c = getch()) x = x * 10 + (c & 15);
    return x * f;
}
template <typename T>
inline void write(const T& t) {
    register T x = t;
    if (!x) putch('0');
    if (x < 0) putch('-'), x = -x;
    while (x) qu[++qr] = x % 10 + '0', x /= 10;
    while (qr) putch(qu[qr--]);
    return;
}
struct Flusher_ {
    ~Flusher_() { flush(); }
} io_flusher_;
}  // namespace io
using io::getch;
using io::getstr;
using io::putch;
using io::putstr;
using io::read;
using io::write;

#define maxn 100005

class SegmentTree {
   public:
    struct Node {
        unsigned long long l0, l1, r0, r1;
        Node(void) { l0 = 0, l1 = ~0, r0 = 0, r1 = ~0; }
        Node operator+(const Node& b) const {
            Node a = *this, ans;
            ans.l0 = (a.l0 & b.l1) | (~a.l0 & b.l0);
            ans.l1 = (a.l1 & b.l1) | (~a.l1 & b.l0);
            ans.r0 = (b.r0 & a.r1) | (~b.r0 & a.r0);
            ans.r1 = (b.r1 & a.r1) | (~b.r1 & a.r0);
            return ans;
        }
    };
    Node tree[maxn << 2];
    int n;

    void update(int p, int l, int r, int qp, int opt, unsigned long long val) {
        if (l == r) {
            if (opt == 1)
                tree[p].l0 = tree[p].r0 = 0, tree[p].l1 = tree[p].r1 = val;
            else if (opt == 2)
                tree[p].l0 = tree[p].r0 = val, tree[p].l1 = tree[p].r1 = ~0;
            else
                tree[p].l0 = tree[p].r0 = val, tree[p].l1 = tree[p].r1 = ~val;
            return;
        }
        int mid = (l + r) >> 1;
        qp <= mid ? update(p << 1, l, mid, qp, opt, val) : update(p << 1 | 1, mid + 1, r, qp, opt, val);
        tree[p] = tree[p << 1] + tree[p << 1 | 1];
        return;
    }
    Node merge(int p, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) return tree[p];
        int mid = (l + r) >> 1;
        Node answer;
        if (ql <= mid) answer = answer + merge(p << 1, l, mid, ql, qr);
        if (qr > mid) answer = answer + merge(p << 1 | 1, mid + 1, r, ql, qr);
        return answer;
    }
    inline void resize(int tn) { return n = tn, void(); }
    inline void update(int p, int opt, unsigned long long val) { return update(1, 1, n, p, opt, val); }
    inline Node merge(int l, int r) { return merge(1, 1, n, l, r); }
};

SegmentTree tree;

int opt[maxn];
unsigned long long val[maxn];

vector<vector<int> > graph;

int dep[maxn], siz[maxn], son[maxn], top[maxn], fa[maxn], dfn[maxn];

void dfs1(int p, int pre = 0) {
    dep[p] = dep[pre] + 1, fa[p] = pre, siz[p] = 1;
    for (vector<int>::iterator i = graph[p].begin(); i != graph[p].end(); i++)
        if (*i != pre) {
            dfs1(*i, p);
            if (!son[p] || siz[*i] > siz[son[p]]) son[p] = *i;
            siz[p] += siz[*i];
        }
    return;
}
int tim = 0;
void dfs2(int p, int t) {
    dfn[p] = ++tim, top[p] = t;
    if (son[p]) dfs2(son[p], t);
    for (vector<int>::iterator i = graph[p].begin(); i != graph[p].end(); i++)
        if (*i != fa[p] && *i != son[p]) dfs2(*i, *i);
    return;
}
int lca(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

int main() {
    int n = read<int>(), m = read<int>(), k = read<int>();
    graph.resize(n + 1);
    for (register int i = 1; i <= n; i++) opt[i] = read<int>(), val[i] = read<unsigned long long>();
    for (register int i = 1, from, to; i < n; i++) {
        from = read<int>(), to = read<int>();
        graph[from].push_back(to), graph[to].push_back(from);
    }
    dfs1(1), dfs2(1, 1);
    tree.resize(n);
    for (register int i = 1; i <= n; i++) tree.update(dfn[i], opt[i], val[i]);
    while (m--)
        if (read<int>() == 1) {
            int x = read<int>(), y = read<int>(), t = lca(x, y);
            unsigned long long z = read<unsigned long long>();
            SegmentTree::Node record1, record2;
            while (top[x] != top[t]) record1 = tree.merge(dfn[top[x]], dfn[x]) + record1, x = fa[top[x]];
            record1 = tree.merge(dfn[t], dfn[x]) + record1;
            while (top[y] != top[t]) record2 = tree.merge(dfn[top[y]], dfn[y]) + record2, y = fa[top[y]];
            if (y != t) record2 = tree.merge(dfn[t] + 1, dfn[y]) + record2;
            swap(record1.l0, record1.r0), swap(record1.l1, record1.r1);
            SegmentTree::Node record = record1 + record2;
            bool up = true;
            unsigned long long answer = 0;
            for (register int i = k - 1; ~i; i--) {
                unsigned long long l0 = (record.l0 >> i & 1), l1 = (record.l1 >> i & 1), t = (z >> i & 1);
                if (!up)
                    answer |= max(l0, l1) << i;
                else if (t == 0)
                    answer |= l0 << i;
                else if (l0 >= l1)
                    answer |= l0 << i, up = false;
                else
                    answer |= l1 << i;
            }
            cout << answer << endl;
        } else {
            int p = read<int>(), opt = read<int>();
            unsigned long long val = read<unsigned long long>();
            tree.update(dfn[p], opt, val);
        }
    return 0;
}

我缓慢吐出一串啊吧啊吧并不再想说话