Uoj#348.【WC2018】州区划分 | wzf2000's blog

Uoj#348.【WC2018】州区划分

Uoj#348.【WC2018】州区划分 题解

题意:

  • 让你将$n$个点分为有顺序的若干的点集,保证每个点集不构成欧拉图。
  • 询问划分方案的每个点集的满意值的乘积的和,每个点集的满意度为这个点集点权和跟在其之前(包括本身)的点集的点权和的比的$p$次方,即
  • $1\le n\le 21,0\le p\le 2$。

题解:

  • 作为参加过$\rm WC$的人(菜),我在跟出题人心灵相通(雾)的情况下,还是只写了一个$50$分暴力。
  • 讲道理这道题的$p\le 2$并没有什么用处。
  • 似乎大家都不相信这么简单于是都没有做出来。
  • 显然我们可以非常轻松地预处理出,哪些点集是合法的。
  • 然后我们列出一个转移方程:
  • 稍微化简一下:
  • 显然互不相交的子集卷积还是比较难处理的,我们考虑加上一维点数来限制,然后只需要取并即可。
  • 具体就是,$dp[i][j]$代表有$i$个点,点集为$j$的满意度和。
  • 转移就变为:
  • 最后再将对应的$dp[i]$出掉相应的$sum$即可。
  • 我们发现$dp[i]$转移就相当于跟$dp[1\sim i-1]$和$g[i][j]=[bitcount(j)=i]sum_j^p[check(j)]$的集合并卷积。
  • 所以如果我们暴力使用$\rm FMT$,就将进行$n^2$次,也就是复杂度$O(2^nn^3)$。
  • 但是如果我们对每个$dp[i],g[i]$只进行一次$\rm FMT$,在$dp[i]$刚算出时进行一次$\rm IFMT$,这样复杂度就降到了$O(2^nn^2)$。
  • 然而我竟然因为$O(2^nn)$次的快速幂而$\rm TLE$了。。。

代码:

#include <bits/stdc++.h>
#define gc getchar()
#define ll long long
#define N 22
#define mod 998244353
#define clz(x) __builtin_popcount(x)
using namespace std;
int n,m,p,mp[N][N],d[N],h[1<<(N-1)],fa[N],w[N],g[N][1<<(N-1)],dp[N][1<<(N-1)];
int read()
{
    int x=1;
    char ch;
    while (ch=gc,ch<'0'||ch>'9') if (ch=='-') x=-1;
    int s=ch-'0';
    while (ch=gc,ch>='0'&&ch<='9') s=s*10+ch-'0';
    return s*x;
}
int ksm(int x,int y,int ret=1)
{
    for (;y;y>>=1,x=(ll)x*x%mod) if (y&1) ret=(ll)ret*x%mod;
    return ret;
}
int find(int x)
{
    return fa[x]==x?x:fa[x]=find(fa[x]);
}
bool check(int x)
{
    for (int i=1;i<=n;i++) d[i]=0,fa[i]=i;
    for (int i=1;i<n;i++)
        if (x>>(i-1)&1)
            for (int j=i+1;j<=n;j++)
                if (x>>(j-1)&1&&mp[i][j]) d[i]++,d[j]++,fa[find(i)]=find(j);
    int anc=0;
    for (int i=1;i<=n;i++)
        if (x>>(i-1)&1)
        {
            if (d[i]&1) return 1;
            if (find(i)!=i) continue;
            if (!anc) anc=i;
            else if (i!=anc) return 1;
        }
    return 0;
}
int get(int x)
{
    if (!p) return 1;
    int ret=0;
    for (int i=1;i<=n;i++) if (x>>(i-1)&1) ret+=w[i];
    if (p==1) return ret;
    else return ret*ret;
}
void fmt(int *a,int flag)
{
    for (int i=0;i<n;i++)
        for (int j=0;j<(1<<n);j++) if (j>>i&1) a[j]=(a[j]+(ll)(mod+flag)*a[j^(1<<i)])%mod;
}
int main()
{
    n=read(),m=read(),p=read();
    for (int i=1;i<=m;i++)
    {
        int x=read(),y=read();
        mp[x][y]=mp[y][x]=1;
    }
    for (int i=1;i<=n;i++) w[i]=read();
    for (int i=0;i<(1<<n);i++)
    {
        h[i]=get(i);
        if (check(i)) g[clz(i)][i]=h[i];
        h[i]=ksm(h[i],mod-2);
    }
    for (int i=1;i<=n;i++) fmt(g[i],1);
    dp[0][0]=1,fmt(dp[0],1);
    for (int i=1;i<=n;i++)
    {
        for (int j=0;j<i;j++)
            for (int k=0;k<(1<<n);k++) dp[i][k]=(dp[i][k]+(ll)dp[j][k]*g[i-j][k])%mod;
        fmt(dp[i],-1);
        for (int k=1;k<(1<<n);k++) dp[i][k]=(ll)dp[i][k]*h[k]%mod;
        fmt(dp[i],1);
    }
    fmt(dp[n],-1);
    printf("%d\n",dp[n][(1<<n)-1]);
    return 0;
}