CF1626F A Random Code Problem

发布于 2022-02-01  2.53k 次阅读


题目链接

题目大意

给出一个长度为 $n$ 的序列,你需要进行 $k$ 次操作,第 $i$ 次操作将会任意选择一个序列元素,将它的值加入到答案后将该数减去其模 $i$ 的值。求 $k$ 次操作后你的答案的期望。

$n \le 10^7,~k \le 17,~a_i < 998244353$

分析

我们发现若第 $x_1,~x_2,\dots,~x_n$ 次操作均在某一值为 $v$ 的元素上进行,该元素最终值为 $v - v \bmod \mathrm{lcm} _i ~ x_i$。由于操作数量不超过 $17$,容易发现数组内的每一个元素在结束后都一定不小于 $v - v \bmod \mathrm{lcm} _{1 \le i \le 17} ~ i$。

我们令 $L = \mathrm{lcm}_{1 \le i \le k}~i$。则我们可以将每个元素 $v$ 分为两部分,第一部分为 $v - v \bmod L$,第二部分为 $v \% L$。考虑对这两部分分开进行计算。

第一部分由于在经过任意操作后该数的第一部分均不会发生变化,因此第一部分对答案的贡献即为该数被操作到的期望数,即 $(v - v \bmod L) \times \frac {k \times n^{k-1}} {n^k}$。

第二部分值小于 $L$,由于值域较小而数组元素数量较多,考虑使用桶进行计数。不难计算出 $f[i][j]$ 表示所有使用了 $i$ 次操作的方案(共 $n^i$ 种)中的值为 $j$ 的元素数量之和,通过对每一个状态统计下次操作若操作到该值时该状态对总答案的贡献即可求出答案: $\frac 1 {n^k} \times \sum_i \sum_j j \times f[i][j] \times n^{k-i-1}$。

当 $k=17$ 时 $L$ 可能过大以至于 DP 时间复杂度不可接受。容易发现第 $k$ 次操作的修改操作对总答案无影响,因此令 $L = \mathrm{lcm}_{1 \le i < k}~i$ 即可。

总时间复杂度 $O(n + k \times \mathrm{lcm}_{1 \le i < k}~i)$。

代码

View on GitHub

/**
 * @file 1626F.cpp
 * @author Macesuted (i@macesuted.moe)
 * @date 2022-02-01
 *
 * @copyright Copyright (c) 2022
 * @brief
 *      My tutorial: https://macesuted.moe/article/cf1626f
 */

#include <bits/stdc++.h>
using namespace std;

#define MP make_pair
#define MT make_tuple

namespace io {
#define SIZE (1 << 20)
char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55];
int f, qr;
inline void flush(void) { return fwrite(obuf, 1, oS - obuf, stdout), oS = obuf, void(); }
inline char getch(void) {
    return (iS == iT ? (iT = (iS = ibuf) + fread(ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS++)) : *iS++);
}
void putch(char x) {
    *oS++ = x;
    if (oS == oT) flush();
    return;
}
string getstr(void) {
    string s = "";
    char c = getch();
    while (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF) c = getch();
    while (!(c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF)) s.push_back(c), c = getch();
    return s;
}
void putstr(string str, int begin = 0, int end = -1) {
    if (end == -1) end = str.size();
    for (int i = begin; i < end; i++) putch(str[i]);
    return;
}
template <typename T>
T read() {
    T x = 0;
    for (f = 1, c = getch(); c < '0' || c > '9'; c = getch())
        if (c == '-') f = -1;
    for (x = 0; c <= '9' && c >= '0'; c = getch()) x = x * 10 + (c & 15);
    return x * f;
}
template <typename T>
void write(const T& t) {
    T x = t;
    if (!x) putch('0');
    if (x < 0) putch('-'), x = -x;
    while (x) qu[++qr] = x % 10 + '0', x /= 10;
    while (qr) putch(qu[qr--]);
    return;
}
struct Flusher_ {
    ~Flusher_() { flush(); }
} io_flusher_;
}  // namespace io
using io::getch;
using io::getstr;
using io::putch;
using io::putstr;
using io::read;
using io::write;

bool mem1;

#define maxk 18
#define maxn 10000005
#define maxL 720725
#define mod 998244353

int a[maxn];
long long f[maxk][maxL];

long long Pow(long long a, long long x) {
    long long ans = 1;
    while (x) {
        if (x & 1) ans = ans * a % mod;
        a = a * a % mod, x >>= 1;
    }
    return ans;
}

void solve(void) {
    int n = read<int>(), a0 = read<int>(), x = read<int>(), y = read<int>(), k = read<int>(), M = read<int>();
    a[1] = a0;
    for (int i = 2; i <= n; i++) a[i] = (1LL * a[i - 1] * x % M + y) % M;
    int L = 720720;
    long long ans = 0, coeff = k * Pow(n, k - 1) % mod;
    for (int i = 1; i <= n; i++) {
        int rest = a[i] % L;
        f[0][rest]++;
        ans = (ans + 1LL * (a[i] - rest) * coeff) % mod;
    }
    for (int i = 0; i < k; i++) {
        long long t = Pow(n, k - i - 1);
        for (int j = 0; j < L; j++)
            if (f[i][j]) {
                f[i + 1][j] = (f[i + 1][j] + 1LL * f[i][j] * (n - 1) % mod) % mod;
                f[i + 1][j - j % (i + 1)] = (f[i + 1][j - j % (i + 1)] + f[i][j]) % mod;
                ans = (ans + 1LL * j * f[i][j] % mod * t) % mod;
            }
    }
    cout << ans << endl;
    return;
}

bool mem2;

int main() {
#ifdef MACESUTED
    cerr << "Memory: " << abs(&mem1 - &mem2) / 1024. / 1024. << "MB" << endl;
#endif

    int _ = 1;
    while (_--) solve();

#ifdef MACESUTED
    cerr << "Time: " << clock() * 1000. / CLOCKS_PER_SEC << "ms" << endl;
#endif
    return 0;
}

我缓慢吐出一串啊吧啊吧并不再想说话