LG5608 [Ynoi2013] 文化课

发布于 2021-08-13  3.69k 次阅读


题目链接

题目大意

现在有一个数字序列 ${a_1,~a_2,\dots,~a_n}$ 和一个运算符序列 ${p_1,~p_2,\dots,~a_{n-1}}$。

定义 $w(l,~r)$ 表示 $a_l~p_l~a_{l+1}~p_{l+1}~\dots~a_{r-1}~p_{r-1}~a_r$ 对 $10^9+7$ 取模后的结果。

现有 $m$ 个操作:

  • 将 $a_l \sim a_r$ 修改为 $x$。
  • 将 $p_l \sim p_r$ 修改为 $x$。
  • 求 $w(l,~r)$ 的值。

$1 \le n,~m \le 10^5,~1 \le a_i < 2^{32},~p_i \in {+,~\times}$

分析

考虑我们如何暴力计算 $w(l,~r)$,我们会先将 $[l,~r]$ 区间切分为若干满足每段内所有符号均为 $\times$ 的极长段。如 $1+3\times5\times7+9\times11$ 将被分割为 $[1],~[3,~5,~7],~[9,~11]$。分割完后我们对每个极长段求出段内乘积,再将所有段的结果加起来即为我们所求的 $w(l,~r)$。

先考虑没有修改操作只有查询操作的情况,对于每一个查询区间我们需要知道区间内所有极长段的乘积之和,容易想到使用线段树维护。线段树上每一个结点维护对应区间的最左端极长段乘积,最右端极长段乘积和其他极长段乘积之和。合并两个区间信息时判断左结点的最右端极长段与右结点的最左端极长段是否能够连接即可。

考虑加入 1 操作,对 $a$ 序列的区间修改操作在线段树上体现为对 $O(\log n)$ 个区间的整段修改操作。我们发现对于一个线段树结点,当他对应的区间被整段修改后其所有极长段左右端点均没有发生变化,而所有极长段的段内乘积会发生改变。我们考虑对于每个节点维护一个桶维护其所有非最左端也非最右端的极长段长度,在将该节点整段修改为 $x$ 时只需要将答案更新为 $\sum_{i} x^{i} \times midLen[i]$ 即可。同时我们也需要记录最左端极长段长度和最右端极长段长度,这样在合并线段树结点信息时即可将左结点的最右端极长段与右结点的最左端极长段合并后构成的新极长段长度加入桶中。

考虑加入 2 操作,同样我们也可以将对 $p$ 序列的区间修改在线段树上体现为对 $O(\log n)$ 个区间的整段修改操作。由于修改为 $+$ 和修改为 $\times$ 的情况不同,我们分开讨论:

  • 整段修改为 $+$: 修改后该区间内会产生 $len$ 个长为 $1$ 的极长段,$ans$ 将变为 $\sum_{i=l+1}^{r-1} a_i$。为此对每个节点我们维护整段元素和以快速维护此修改操作。
  • 整段修改为 $\times$: 修改后该区间内会产生 $1$ 个长为 $len$ 的极长段,此时该极长段的乘积为 $\prod_{i=l}^{r} a_i$。为此对每个节点我们维护整段元素乘积以快速维护此修改操作。

在线段树上维护上述所有信息即可,此时时间复杂度为 $O(n \times \log n + m \times n \times \log n)$,空间复杂度为 $O(n \times \log n)$。时间复杂度与空间复杂度均无法通过此题。

优化 1

对于每个线段树结点,其区间内所有极长段的长度之和为 $len$,因此最多只会存在 $\sqrt {len}$ 种不同的极长段长度,将相同长度的极长段的信息在一起存储,使用大小为 $\sqrt {len}$ 的 vector 存储即可。

此时时间复杂度 $O(n + m \times \sqrt n \times \log n)$,空间复杂度 $O(n)$。时间复杂度仍无法通过此题。

优化 2

在区间修改元素值时我们对 $O(\sqrt n)$ 个长度都 $O(\log n)$ 求该长度对应区间修改后的乘积。考虑存储连续段长度时按连续段长度升序存储,需要对第 $i$ 个长度求值时从第 $i-1$ 个长度对应的答案转移过来。因为 $O(\sum_{i} \log (a_i - a_{i-1})) = O(\log a_n) = O(\sqrt len)$,所以花在对 $O(\sqrt n)$ 个长度求值的时间复杂度转为 $O(\sqrt n)$。

此时时间复杂度 $O(n + m \times \sqrt n)$,空间复杂度 $O(n)$,可以通过此题。

也可以考虑以恰当块长对线段树底层进行分块以继续减少空间占用。

代码

View on GitHub

/**
 * @author Macesuted (i@macesuted.moe)
 * @copyright Copyright (c) 2021
 * @brief
 *      My solution: https://macesuted.moe/article/lg5608
 */

#include <bits/stdc++.h>
using namespace std;

namespace io {
#define SIZE (1 << 20)
char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55];
int f, qr;
inline void flush(void) { return fwrite(obuf, 1, oS - obuf, stdout), oS = obuf, void(); }
inline char getch(void) { return (iS == iT ? (iT = (iS = ibuf) + fread(ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS++)) : *iS++); }
inline void putch(char x) {
    *oS++ = x;
    if (oS == oT) flush();
    return;
}
string getstr(void) {
    string s = "";
    char c = getch();
    while (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF) c = getch();
    while (!(c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF)) s.push_back(c), c = getch();
    return s;
}
void putstr(string str, int begin = 0, int end = -1) {
    if (end == -1) end = str.size();
    for (register int i = begin; i < end; i++) putch(str[i]);
    return;
}
template <typename T>
inline T read() {
    register T x = 0;
    for (f = 1, c = getch(); c < '0' || c > '9'; c = getch())
        if (c == '-') f = -1;
    for (x = 0; c <= '9' && c >= '0'; c = getch()) x = x * 10 + (c & 15);
    return x * f;
}
template <typename T>
inline void write(const T& t) {
    register T x = t;
    if (!x) putch('0');
    if (x < 0) putch('-'), x = -x;
    while (x) qu[++qr] = x % 10 + '0', x /= 10;
    while (qr) putch(qu[qr--]);
    return;
}
struct Flusher_ {
    ~Flusher_() { flush(); }
} io_flusher_;
}  // namespace io
using io::getch;
using io::getstr;
using io::putch;
using io::putstr;
using io::read;
using io::write;

#define maxn 100005
#define sqrtn 320
#define mod 1000000007
#define leafLen 50

typedef pair<int, int> pii;

long long Pow(long long a, long long x) {
    long long ans = 1;
    while (x) {
        if (x & 1) ans = ans * a % mod;
        a = a * a % mod;
        x >>= 1;
    }
    return ans;
}

inline int Mod(int x) { return x >= mod ? x - mod : x; }

pii cache[sqrtn];

class SegmentTree {
   private:
    struct Node {
        bool op, hasLazyOp, lazyOp, hasLazyNum;
        long long allMul, lMul, rMul;
        int allPlus, ans, lazyNum;
        int lNum, rNum, lLen, rLen;
        pii midLen[sqrtn];
        int tail;
        int len, sqrtLen;
        Node *l, *r;
        Node(void) { hasLazyOp = hasLazyNum = false, l = r = NULL, tail = 0; }
        void add(int val, int cnt = 1) {
            if (val == 0) return;
            for (register int i = 1; i <= tail; i++)
                if (midLen[i].first == val) return midLen[i].second += cnt, void();
            int p = tail++;
            while (p > 0 && midLen[p].first > val) swap(midLen[p], midLen[p + 1]), p--;
            midLen[p + 1] = (pii){val, cnt};
            return;
        }
        Node operator+(const Node& oth) const {
            Node a = *this, b = oth;
            a.hasLazyOp = a.hasLazyNum = false;
            a.allMul = a.allMul * b.allMul % mod;
            a.allPlus = Mod(a.allPlus + b.allPlus);
            a.rNum = b.rNum;
            a.sqrtLen = sqrt(a.len += b.len);
            int ctail = a.tail;
            for (register int i = 1; i <= a.tail; i++) cache[i] = a.midLen[i];
            int pb = 1, pc = 1;
            a.tail = 0;
            while (pb <= b.tail && pc <= ctail)
                if (b.midLen[pb].first == cache[pc].first)
                    a.midLen[++a.tail] = (pii){b.midLen[pb].first, b.midLen[pb].second + cache[pc].second}, pb++, pc++;
                else if (b.midLen[pb].first < cache[pc].first)
                    a.midLen[++a.tail] = b.midLen[pb], pb++;
                else
                    a.midLen[++a.tail] = cache[pc], pc++;
            while (pb <= b.tail) a.midLen[++a.tail] = b.midLen[pb], pb++;
            while (pc <= ctail) a.midLen[++a.tail] = cache[pc], pc++;
            if (!a.op) {
                if (!a.lLen) swap(a.lLen, a.rLen), swap(a.lMul, a.rMul);
                if (!b.rLen) swap(b.lLen, b.rLen), swap(b.lMul, b.rMul);
                a.ans = Mod(Mod(a.ans + b.ans) + Mod(a.rMul + b.lMul));
                a.add(a.rLen), a.add(b.lLen);
                a.rLen = b.rLen, a.rMul = b.rMul;
            } else {
                if (!a.rLen) swap(a.lLen, a.rLen), swap(a.lMul, a.rMul);
                if (!b.lLen) swap(b.lLen, b.rLen), swap(b.lMul, b.rMul);
                int nLen = a.rLen + b.lLen, nMul = a.rMul * b.lMul % mod;
                a.ans = Mod(a.ans + b.ans);
                a.rLen = b.rLen, a.rMul = b.rMul;
                if (!a.lLen)
                    a.lLen = nLen, a.lMul = nMul;
                else if (!a.rLen)
                    a.rLen = nLen, a.rMul = nMul;
                else
                    a.ans = Mod(a.ans + nMul), a.add(nLen);
            }
            a.op = b.op;
            return a;
        }
        int getAns(void) {
            int ans = this->ans;
            if (lLen) ans = Mod(ans + lMul);
            if (rLen) ans = Mod(ans + rMul);
            return ans;
        }
    };

    Node* root;

    int a[maxn], op[maxn];

    int n;

    Node merge(Node* l, Node* r) {
        Node ans = *l + *r;
        ans.l = l, ans.r = r;
        return ans;
    }
    Node reCalc(int l, int r) {
        Node p;
        p.op = op[r], p.ans = 0;
        p.tail = 0;
        p.lNum = a[l], p.rNum = a[r];
        p.lLen = 1, p.lMul = a[l], p.rLen = 0, p.rMul = 0;
        p.sqrtLen = sqrt(p.len = r - l + 1);
        while (l + p.lLen <= r && op[l + p.lLen - 1]) p.lMul = p.lMul * a[l + p.lLen] % mod, p.lLen++;
        if (l + p.lLen - 1 < r) {
            p.rLen = 1, p.rMul = a[r];
            while (op[r - p.rLen]) p.rMul = p.rMul * a[r - p.rLen] % mod, p.rLen++;
            int tl = l + p.lLen, tr = r - p.rLen;
            if (tl <= tr) {
                long long last = a[tl];
                int lastPos = tl;
                for (register int i = tl; i < tr; i++)
                    if (op[i])
                        last = last * a[i + 1] % mod;
                    else
                        p.ans = Mod(p.ans + last), p.add(i - lastPos + 1), last = a[lastPos = i + 1];
                p.add(tr - lastPos + 1), p.ans = Mod(p.ans + last);
            }
        }
        p.allMul = 1, p.allPlus = 0;
        for (register int i = l; i <= r; i++) p.allMul = p.allMul * a[i] % mod, p.allPlus = Mod(p.allPlus + a[i]);
        return p;
    }
    void modifyNum(Node* p, int l, int r, int num) {
        if (r - l + 1 < leafLen) {
            for (register int i = l; i <= r; i++) a[i] = num;
            *p = reCalc(l, r);
            return;
        }
        p->allMul = Pow(num, r - l + 1), p->allPlus = 1LL * (r - l + 1) * num % mod, p->ans = 0;
        long long lastPow = 1;
        int lastPos = 0;
        for (register int i = 1; i <= p->tail; i++)
            lastPow = lastPow * Pow(num, p->midLen[i].first - lastPos) % mod, lastPos = p->midLen[i].first,
            p->ans = (p->ans + lastPow * p->midLen[i].second) % mod;
        p->lNum = p->rNum = num;
        if (p->lLen) p->lMul = Pow(num, p->lLen);
        if (p->rLen) p->rMul = Pow(num, p->rLen);
        p->hasLazyNum = true, p->lazyNum = num;
        return;
    }
    inline void modifyOp(Node* p, int l, int r, bool _op) {
        if (r - l + 1 < leafLen) {
            for (register int i = l; i <= r; i++) op[i] = _op;
            *p = reCalc(l, r);
            return;
        }
        p->op = _op;
        if (!_op) {
            p->ans = Mod(Mod(p->allPlus + mod - p->lNum) + mod - p->rNum);
            p->lLen = 1, p->lMul = p->lNum;
            p->rLen = 1, p->rMul = p->rNum;
            p->tail = 0, p->add(1, r - l - 1);
        } else {
            p->ans = 0;
            p->lLen = r - l + 1, p->lMul = p->allMul;
            p->rLen = 0, p->rMul = 0;
            p->tail = 0;
        }
        p->hasLazyOp = true, p->lazyOp = _op;
        return;
    }
    inline void pushDown(Node* p, int l, int r) {
        int mid = (l + r) >> 1;
        if (p->hasLazyOp) {
            p->hasLazyOp = false;
            modifyOp(p->l, l, mid, p->lazyOp), modifyOp(p->r, mid + 1, r, p->lazyOp);
        }
        if (p->hasLazyNum) {
            p->hasLazyNum = false;
            modifyNum(p->l, l, mid, p->lazyNum), modifyNum(p->r, mid + 1, r, p->lazyNum);
        }
        return;
    }
    void build(Node*& p, int l, int r, int _a[], bool _op[]) {
        if (p == NULL) p = new Node();
        if (r - l + 1 < leafLen) {
            for (register int i = l; i <= r; i++) a[i] = _a[i], op[i] = _op[i];
            p->sqrtLen = sqrt(p->len = r - l + 1);
            *p = reCalc(l, r);
            return;
        }
        int mid = (l + r) >> 1;
        build(p->l, l, mid, _a, _op), build(p->r, mid + 1, r, _a, _op);
        *p = merge(p->l, p->r);
        return;
    }
    void updateNum(Node* p, int l, int r, int ql, int qr, int val) {
        if (ql <= l && r <= qr) return modifyNum(p, l, r, val);
        if (r - l + 1 < leafLen) {
            for (register int i = max(l, ql); i <= min(r, qr); i++) a[i] = val;
            *p = reCalc(l, r);
            return;
        }
        pushDown(p, l, r);
        int mid = (l + r) >> 1;
        if (ql <= mid) updateNum(p->l, l, mid, ql, qr, val);
        if (qr > mid) updateNum(p->r, mid + 1, r, ql, qr, val);
        *p = merge(p->l, p->r);
        return;
    }
    void updateOp(Node* p, int l, int r, int ql, int qr, bool _op) {
        if (ql <= l && r <= qr) return modifyOp(p, l, r, _op);
        if (r - l + 1 < leafLen) {
            for (register int i = max(l, ql); i <= min(r, qr); i++) op[i] = _op;
            *p = reCalc(l, r);
            return;
        }
        pushDown(p, l, r);
        int mid = (l + r) >> 1;
        if (ql <= mid) updateOp(p->l, l, mid, ql, qr, _op);
        if (qr > mid) updateOp(p->r, mid + 1, r, ql, qr, _op);
        *p = merge(p->l, p->r);
        return;
    }
    Node getAns(Node* p, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) return *p;
        if (r - l + 1 < leafLen) return reCalc(max(l, ql), min(r, qr));
        int mid = (l + r) >> 1;
        pushDown(p, l, r);
        if (qr <= mid) return getAns(p->l, l, mid, ql, qr);
        if (ql > mid) return getAns(p->r, mid + 1, r, ql, qr);
        return getAns(p->l, l, mid, ql, qr) + getAns(p->r, mid + 1, r, ql, qr);
    }

   public:
    SegmentTree(void) { root = NULL; }
    inline void resize(int _n) { return n = _n, void(); }
    inline void build(int a[], bool op[]) { return build(root, 1, n, a, op); }
    inline void updateNum(int l, int r, int val) { return updateNum(root, 1, n, l, r, val); }
    inline void updateOp(int l, int r, bool op) { return updateOp(root, 1, n, l, r, op); }
    inline int getAns(int l, int r) { return getAns(root, 1, n, l, r).getAns(); }
};

SegmentTree tree;

int a[maxn];
bool op[maxn];

int main() {
    int n = read<int>(), m = read<int>();
    for (register int i = 1; i <= n; i++) a[i] = read<long long>() % mod;
    for (register int i = 1; i < n; i++) op[i] = read<int>();
    tree.resize(n), tree.build(a, op);
    while (m--) {
        int opt = read<int>();
        if (opt == 1) {
            int l = read<int>(), r = read<int>();
            tree.updateNum(l, r, read<long long>() % mod);
        } else if (opt == 2) {
            int l = read<int>(), r = read<int>();
            tree.updateOp(l, r, read<int>());
        } else {
            int l = read<int>(), r = read<int>();
            write(tree.getAns(l, r)), putch('\n');
        }
    }
    return 0;
}

我缓慢吐出一串啊吧啊吧并不再想说话