HDU7162. Equipment Upgrade (2022杭电多校第3场1001)
题意
有一件装备,一开始是 /(0/) 级,可以强化它,当它在第 /(i/) 级时,需要花费 /(c_i/) 强化它,有 /(p_i/) 的概率强化成功(升高一级),/(1-p_i/) 的概率强化失败(降 /(1/) 至 /(i/) 级),其中降 /(j/) 级的权重为 /(w_j/),也就是说有 /((1-p_i)/frac{w_j}{/sum/limits_{k=1}^i w_k}/) 的概率降 /(j/) 级。求从 /(0/) 级升级到 /(n/) 级的期望花费。
分析
设 /(f(i)/) 为处于第 /(i/) 级的期望代价,即要算 /(f(0)/)。
记 /(S_i=/sum/limits_{j=1}^i w_j/)
根据题意,可以写出式子
/[f(i)=p_if(i+1)+/frac{1-p_i}{S_i}/sum_{j=1}^i w_jf(i-j)+c_i
/]
上面的式子后面的和式已经出现了卷积形式,但是不便于求解。
容易发现,上面式子包含 /(f(0),f(1),…,f(i+1)/)。稍作移项,得到
/[f(i+1)=/frac{f(i)-c_i-/frac{1-p_i}{S_i}/sum/limits_{j=1}^i w_jf(i-j)}{p_i}
/]
这说明如果已知 /(f(0),f(1),…,f(i)/),就能线性地推出 /(f(i+1)/)
而现在已知 /(f(n)/) 是 /(0/),要求 /(f(0)/)。我们可以尝试将 /(f(i)/) 线性地用 /(f(0)/) 表示,具体地说就是设
/[f(i)=a_i f(0)+b_i
/]
我们只要求出 /(a_i/) 和 /(b_i/),最终答案就是 /(-/frac{b_n}{a_n}/)
根据 /(f(i)/) 的递推式和所设 /(a_i/) 与 /(b_i/) 的含义,可以推得 /(a_i/) 和 /(b_i/) 的递推式。
/[a_{i+1}=/frac{a_i-/frac{1-p_i}{S_i}/sum/limits_{j=1}^i w_j a_{i-j}}{p_i}//
b_{i+1}=/frac{b_i-c_i-/frac{1-p_i}{S_i}/sum/limits_{j=1}^i w_j b_{i-j}}{p_i}
/]
这样就可以CDQ分治配合卷积 /(O(n/log^2n)/) 求解 /(a_i/) 和 /(b_i/) 了。
CDQ分治的时候注意 /(w_i/) 和 /(a_j/) / /(b_j/) 贡献在了 /(a_{i+j+1}/) / /(b_{i+j+1}/) 上。调用NTT卷积时,要仔细计算好相应的位置。
代码
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
namespace NTT {
typedef int Lint;
typedef long long LLint;
// 2的幂次
const int maxn = (1 << 21) + 10;
const Lint mod = 998244353;
const Lint g = 3;
Lint fpow(Lint a, Lint b, Lint mod) {
Lint res = 1;
for (; b; b >>= 1) {
if (b & 1)
res = (LLint)res * a % mod;
a = (LLint)a * a % mod;
}
return res;
}
inline Lint add(Lint a, Lint b) {
a += b;
return a >= mod ? a - mod : a;
}
inline Lint mul(Lint a, Lint b) {
return (LLint)a * b % mod;
}
int r[maxn];
void cal_r(int n) {
for (int i = 0; i < n; i++) {
r[i] = (i & 1) * (n >> 1) + (r[i >> 1] >> 1);
}
}
void dft(Lint* a, int n, int type) {
for (int i = 0; i < n; i++)
if (i < r[i])
swap(a[i], a[r[i]]);
for (int i = 1; i < n; i <<= 1) {
int p = i << 1;
Lint w = fpow(g, (mod - 1) / p, mod);
if (type == -1)
w = fpow(w, mod - 2, mod);
for (int j = 0; j < n; j += p) {
Lint t = 1;
for (int k = 0; k < i; k++, t = mul(t, w)) {
Lint tmp = mul(a[j + k + i], t);
a[j + k + i] = add(a[j + k], mod - tmp);
a[j + k] = add(a[j + k], tmp);
}
}
}
if (type == -1) {
Lint inv = fpow(n, mod - 2, mod);
for (int i = 0; i < n; i++)
a[i] = mul(a[i], inv);
}
}
Lint p[maxn], q[maxn];
vector<Lint> poly_mul(const vector<Lint>& a, const vector<Lint>& b) {
vector<Lint> res;
int n = a.size(), m = b.size();
res.resize(n + m - 1);
int len = n + m - 1;
int lim = 1;
while (lim < len)
lim <<= 1;
copy(a.begin(), a.end(), p);
fill(p + n, p + lim, 0);
copy(b.begin(), b.end(), q);
fill(q + m, q + lim, 0);
cal_r(lim);
dft(p, lim, 1), dft(q, lim, 1);
for (int i = 0; i < lim; i++)
p[i] = mul(p[i], q[i]);
dft(p, lim, -1);
for (int i = 0; i < n + m - 1; i++)
res[i] = p[i];
return res;
}
}; // namespace NTT
int inv(int a) {
return NTT::fpow(a, NTT::mod - 2, NTT::mod);
}
using NTT::add;
using NTT::mod;
using NTT::mul;
using NTT::poly_mul;
int n;
const int maxn = (1 << 18) + 10;
const int inv100 = 828542813;
int lim;
int w[maxn], p[maxn], sum[maxn], inv_sum[maxn], inv_p[maxn];
int a[maxn], b[maxn], c[maxn];
void solve_ab(int l, int r) {
if (l == r) {
if (l == 0)
return;
a[l] = mul(add(a[l - 1], mod - mul(add(1, mod - p[l - 1]), mul(inv_sum[l - 1], a[l]))), inv_p[l - 1]);
b[l] = mul(add(add(b[l - 1], mod - c[l - 1]), mod - mul(add(1, mod - p[l - 1]), mul(inv_sum[l - 1], b[l]))), inv_p[l - 1]);
return;
}
int mid = l + r >> 1;
solve_ab(l, mid);
vector<int> P(r - l), Q(mid - l + 1);
for (int i = 0; i < r - l; i++)
P[i] = w[i];
for (int i = l; i <= mid; i++)
Q[i - l] = a[i];
vector<int> res = poly_mul(P, Q);
for (int i = mid + 1; i <= r; i++) {
a[i] = add(a[i], res[i - l - 1]);
}
for (int i = 0; i < r - l; i++)
P[i] = w[i];
for (int i = l; i <= mid; i++)
Q[i - l] = b[i];
res = poly_mul(P, Q);
for (int i = mid + 1; i <= r; i++) {
b[i] = add(b[i], res[i - l - 1]);
}
solve_ab(mid + 1, r);
}
void solve() {
cin >> n;
for (int i = 0; i < n; i++) {
cin >> p[i] >> c[i];
p[i] = mul(p[i], inv100);
inv_p[i] = inv(p[i]);
}
for (int i = 1; i <= n - 1; i++) {
cin >> w[i];
sum[i] = add(sum[i - 1], w[i]);
inv_sum[i] = inv(sum[i]);
}
fill(a + 1, a + 1 + n, 0);
fill(b + 1, b + 1 + n, 0);
solve_ab(0, n);
cout << mul(mod - b[n], inv(a[n])) << '/n';
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
a[0] = 1;
int T;
cin >> T;
while (T--)
solve();
return 0;
}
原创文章,作者:,如若转载,请注明出处:https://blog.ytso.com/277190.html