LG6144 [USACO20FEB]Help Yourself P【DP,组合数,线段树】


传送门

思路

考虑 DP,设 /(f_{i,j,k}/) 表示前 /(i/) 条线段,连通块最右端的点为 /(j/) 的所有子集的连通块个数的 /(k/) 次方之和。初值 /(f_{0,0,0} = 1/),答案为 /(/sum f_{n,j,K}/)。

把线段按照左端点排序,考虑加入第 /(i/) 条线段后对答案的影响,设 /(j/) 为加入 /(i/) 之前连通块的右端点:

  • 对于 /(j < l_i/) 的情况,加入 /(i/) 后连通块个数会增加 /(1/),并且右端点变为 /(r_i/)。设 /(f_{i-1,j,k}/) 对应的子集为 /(S/),/(g(S)/) 为 /(S/) 的连通块数,有转移:

/[f_{i,r_i,k} /gets /sum_S (g(S) + 1)^k = /sum_S /sum_{p=0}^k /binom{k}{p} g(S)^p = /sum_{j=0}^{l_i – 1} /sum_{p=0}^k /binom{k}{p} f_{i-1,j,p}
/]

  • 对于 /(l_i /leq j /leq r_i/) 的情况,加入 /(i/) 后连通块数量不变,但右端点变为 /(r_i/),因此转移为 /(f_{i,r_i,k} /gets f_{i-1,j,k}/)。

  • 对于 /(j > r_i/) 的情况,由于我们将线段按照左端点排序,那么 /(i/) 一定被原先的连通块包含,因此加入 /(i/) 后连通块数量不变,右端点也不变。但由于 /(i/) 被包含,则选或不选 /(i/) 对连通块个数都没有影响,故转移为 /(f_{i,j,k} /gets f_{i-1,j,k} /times 2/)。

直接 DP 复杂度巨大,考虑优化。首先 /(i/) 这维可以滚掉,然后对 /(j/) 这维用线段树维护,线段树上每个节点维护子树内位置连通块数量的 /(1 /sim K/) 次方和。

  • 对于 /(j < l_i/) 的情况,利用分配率,可以先对 /([0,l_i-1]/) 区间查询出 /(1 /sim K/) 次方和,然后二项式定理 /(O(K^2)/) 算对 /(f_{i,r_i,1 /sim K}/) 的贡献,最后对位置 /(r_i/) 单点加即可。

  • 对于 /(l_i /leq j /leq r_i/) 的情况,先对 /([l_i,r_i-1]/) 区间查询出 /(1 /sim K/) 次方和,然后位置 /(r_i/) 单点加即可。

  • 对于 /(j > r_i/) 的情况,对 /([r+1,N]/) 区间乘 /(2/) 即可。

时间复杂度为 /(O(n /log K + nK^2)/)。拆 /(K/) 次幂也可以使用斯特林数,复杂度是一样的。

Code
/*
最黯淡的一个 梦最为炽热
万千孤单焰火 让这虚构灵魂鲜活
至少在这一刻 热爱不问为何
存在为将心声响彻
*/
#include <bits/stdc++.h>
#define pii pair<int, int>
#define mp(x, y) make_pair(x, y)
#define pb push_back
#define eb emplace_back
#define fi first
#define se second
#define int long long
#define mem(x, v, n) memset(x, v, sizeof(int) * (n))
#define mcpy(x, y, n) memcpy(x, y, sizeof(int) * (n))
#define lob lower_bound
#define upb upper_bound
using namespace std;

inline int read() {
	int x = 0, w = 1;char ch = getchar();
	while (ch > '9' || ch < '0') { if (ch == '-')w = -1;ch = getchar(); }
	while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
	return x * w;
}

inline int min(int x, int y) { return x < y ? x : y; }
inline int max(int x, int y) { return x > y ? x : y; }

const int MN = 1e5 + 5;
const int Mod = 1e9 + 7;

inline int qPow(int a, int b = Mod - 2, int ret = 1) {
    while (b) {
        if (b & 1) ret = ret * a % Mod;
        a = a * a % Mod, b >>= 1;
    }
    return ret;
}

// #define dbg

int N, K, Lim, C[25][25];
pii p[MN];

inline void Add(int &x, int y) {
    x += y; if (x >= Mod) x -= Mod;
}

const int MS = MN << 3;
#define ls o << 1
#define rs o << 1 | 1
#define mid ((l + r) >> 1)
#define LS ls, l, mid
#define RS rs, mid + 1, r
struct Dat {
    int a[13];
};
int tr[MS][12], tag[MS];
inline void Pushup(int o) {
    for (int i = 0; i <= K; i++) tr[o][i] = (tr[ls][i] + tr[rs][i]) % Mod;
}
inline void Pushdown(int o) {
    if (tag[o] != 1) {
        tag[ls] = tag[ls] * tag[o] % Mod, tag[rs] = tag[rs] * tag[o] % Mod;
        for (int i = 0; i <= K; i++) {
            tr[ls][i] = tr[ls][i] * tag[o] % Mod;
            tr[rs][i] = tr[rs][i] * tag[o] % Mod;
        }
        tag[o] = 1;
    }
}
inline void Add(int o, int l, int r, int p, Dat v) {
    if (l == r) {
        for (int i = 0; i <= K; i++) Add(tr[o][i], v.a[i]);
        return;
    }
    Pushdown(o), (p <= mid ? Add(LS, p, v) : Add(RS, p, v)), Pushup(o);
}
inline void Mul(int o, int l, int r, int L, int R) {
    if (L > R || r < L || l > R) return;
    if (L <= l && R >= r) {
        tag[o] = tag[o] * 2 % Mod;
        for (int i = 0; i <= K; i++) tr[o][i] = tr[o][i] * 2 % Mod;
        return;
    } 
    Pushdown(o), Mul(LS, L, R), Mul(RS, L, R), Pushup(o);
}
inline Dat Qry(int o, int l, int r, int L, int R) {
    Dat ans;
    if (L > R) {
        for (int i = 0; i <= K; i++) ans.a[i] = 0;
        return ans;
    }
    if (L <= l && R >= r) {
        for (int i = 0; i <= K; i++) ans.a[i] = tr[o][i];
        return ans;
    }
    Pushdown(o);
    ans.a[0] = 1;
    for (int i = 0; i <= K; i++) ans.a[i] = 0;
    if (L <= mid) {
        Dat b = Qry(LS, L, R);
        for (int i = 0; i <= K; i++) Add(ans.a[i], b.a[i]);
    }
    if (R > mid) {
        Dat b = Qry(RS, L, R);
        for (int i = 0; i <= K; i++) Add(ans.a[i], b.a[i]);
    }
    return ans;
}

signed main(void) {
    N = read(), K = read(), Lim = 2 * N;
    for (int i = 0; i <= K; i++) C[i][0] = 1;
    for (int i = 1; i <= K; i++)
        for (int j = 1; j <= i; j++) 
            C[i][j] = C[i - 1][j], Add(C[i][j], C[i - 1][j - 1]);
    for (int i = 1; i <= N; i++) p[i].fi = read(), p[i].se = read();
    sort(p + 1, p + N + 1);
    Dat v;
    for (int i = 0; i <= K; i++) v.a[i] = 0;
    v.a[0] = 1;
    Add(1, 0, Lim, 0, v);
    for (int i = 1; i <= N; i++) {
        int l = p[i].fi, r = p[i].se;
        Dat b = Qry(1, 0, Lim, 0, l - 1);
        Dat c;
        for (int i = 0; i <= K; i++) c.a[i] = 0;
        for (int i = 0; i <= K; i++) 
            for (int j = 0; j <= i; j++) Add(c.a[i], b.a[j] * C[i][j] % Mod);
        b = Qry(1, 0, Lim, l, r - 1);
        for (int i = 0; i <= K; i++) Add(c.a[i], b.a[i]);
        Add(1, 0, Lim, r, c);
        Mul(1, 0, Lim, r + 1, Lim);
    }
    printf("%lld/n", Qry(1, 0, Lim, 0, Lim).a[K]);
    return 0;
}

原创文章,作者:Maggie-Hunter,如若转载,请注明出处:https://blog.ytso.com/tech/pnotes/276535.html

(0)
上一篇 2022年7月24日 00:30
下一篇 2022年7月24日 00:39

相关推荐

发表回复

登录后才能评论