SRM710 Div1 hard Hyperboxes

  • 問題概要

{1,2,...,N}^K の点を頂点とした K 次元超直方体を重ならないように M 個置く方法は何通りあるか。

  • 解法

超直方体同士が重ならない <-> どこかの次元でその超直方体が占める区間がdisjoint なので、1次元の場合に区間がどういう重なり方(M <= 6 個の区間のそれぞれのペアに対し重なるか重ならないかの高々 2^15 通り)をするものが何通りあるかを前計算した後、例の A and B に足す convolution で二分累乗する。
後半パートは 2^15 * 15 * log 10^9 となる。

後半パートが本質だと思ってなめてかかったら前半パートが計算量のよくわからないいくつかの方針の中で正しいものを選ばないといけないゲームで、そっちに手間取った。
迷走した痕跡を敢えて残してコードを貼っておきます。

#include<stdio.h>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long ll;
ll mod=998244353;
#define LOG 15
class Andadd
{
public:
	ll a[LOG+1][1<<LOG],b[LOG+1][1<<LOG],ret[LOG+1][1<<LOG];
	vector<ll>conv(vector<ll>va,vector<ll>vb)
	{
		for(int i=0;i<va.size();i++)a[0][i]=va[i];
		for(int i=0;i<vb.size();i++)b[0][i]=vb[i];
		for(int i=0;i<LOG;i++)
		{
			for(int j=0;j<(1<<LOG);j+=(1<<(LOG-i)))
			{
				int s=1<<(LOG-i-1);
				for(int k=0;k<s;k++)a[i+1][j+k]=(a[i][j+k]+a[i][j+s+k])%mod,a[i+1][j+s+k]=a[i][j+s+k];
				for(int k=0;k<s;k++)b[i+1][j+k]=(b[i][j+k]+b[i][j+s+k])%mod,b[i+1][j+s+k]=b[i][j+s+k];
			}
		}
		for(int i=0;i<(1<<LOG);i++)ret[LOG][i]=a[LOG][i]*b[LOG][i]%mod;
		for(int i=LOG-1;i>=0;i--)
		{
			for(int j=0;j<(1<<LOG);j+=(1<<(LOG-i)))
			{
				int s=1<<(LOG-i-1);
				for(int k=0;k<s;k++)ret[i][j+k]=(ret[i+1][j+k]+mod-ret[i+1][j+s+k])%mod;
				for(int k=0;k<s;k++)ret[i][j+s+k]=ret[i+1][j+s+k];
			}
		}
		vector<ll>r;
		for(int i=0;i<(1<<LOG);i++)r.push_back(ret[0][i]);
		return r;
	}
};
ll po(ll a, ll b)
{
	if (b == 0)return 1;
	ll z = po(a, b / 2);
	z = z*z%mod;
	if (b & 1)z = z*a%mod;
	return z;
}
ll iiiinv(ll a)
{
	return po(a, mod - 2);
}
ll inv[100];
ll com(ll a,ll b)
{
	if(b<0||b>a)return 0;
	ll r=1;
	for(int i=1;i<=b;i++)r=r*(a-i+1)%mod*inv[i]%mod;
	return r;
}
ll dat[1<<15];
Andadd ad;
vector<ll>ppooww(vector<ll>a,ll b)
{
	if(b==1)return a;
	vector<ll>z=ppooww(a,b/2);
	z=ad.conv(z,z);
	if(b%2==1)z=ad.conv(z,a);
	return z;
}
#include<map>
map<ll,ll>dp[13];
typedef pair<ll,ll>pii;
int dig[6][6];
class Hyperboxes
{
public:
	int findCount(int len,int num,int dim)
	{
		for(int i=1;i<100;i++)inv[i]=iiiinv(i);
		/*vector<int>v;
		for(int i=0;i<num;i++)v.push_back(i+1),v.push_back(i+1);
		for(;;)
		{
			int now=0;
			bool fff=true;
			for(int i=0;i<v.size();i++)
			{
				if(now+1<v[i])fff=false;
				now=max(now,v[i]);
			}
		//	if(fff)
			{
				int d[10];
				fill(d,d+10,0);
				int st[10],go[10];
				vector<int>x;
				for(int i=0;i<v.size();i++)
				{
					if(d[v[i]]==0)d[v[i]]=1,x.push_back(v[i]),st[v[i]]=i;
					else x.push_back(-v[i]),go[v[i]]=i;
				}
				ll dp[13][13][13];
				for(int i=0;i<13;i++)for(int j=0;j<13;j++)for(int k=0;k<13;k++)dp[i][j][k]=0;
				dp[0][0][1]=1;
				for(int i=1;i<x.size();i++)
				{
					int t=0;
					if(x[i-1]<0&&x[i]>0)t=1;
					else if(x[i-1]>0&&x[i]>0&&v[i-1]<v[i])t=1;
					else if(x[i-1]<0&&x[i]<0&&v[i-1]<v[i])t=1;
					if(t==0)
					{
						for(int k=0;k<=num+num;k++)
						{
							for(int j=i-1;j>=0;j--)
							{
								if(v[i]==v[j])break;
								dp[i][j][k]+=dp[i-1][j][k];
							}
						}
					}
					for(int k=0;k<num+num;k++)
					{
						for(int j=0;j<i;j++)
						{
							dp[i][i][k+1]+=dp[i-1][j][k];
						}
					}
				}
				ll sum=0;
				for(int k=0;k<=num+num;k++)
				{
					ll ss=0;
					for(int j=0;j<=num+num;j++)ss+=dp[num+num-1][j][k];
					//printf(" %lld\n",ss);
					sum+=ss*com(len,k);
					sum%=mod;
				}
				int mask=0;
				int pt=0;
				for(int i=1;i<=num;i++)
				{
					for(int j=i+1;j<=num;j++)
					{
						if(!(go[i]<st[j]||go[j]<st[i]))mask+=(1<<pt);
						pt++;
					}
				}
				dat[mask]=(dat[mask]+sum)%mod;
				//printf("%d %lld\n",mask,sum);
			}
			if(!next_permutation(v.begin(),v.end()))break;
		}
	//	for(int i=0;i<2;i++)printf("%lld\n",dat[i]);*/
		int pppt=0;
		for(int i=0;i<num;i++)
		{
			for(int j=i+1;j<num;j++)
			{
				dig[i][j]=dig[j][i]=pppt;
				pppt++;
			}
		}
		dp[0][0]=1;
		for(int i=0;i<=num+num;i++)
		{
			map<ll,ll>::iterator it=dp[i].begin();
			for(;;)
			{
				if(it==dp[i].end())break;
				pii zz=*it;
				it++;
			//	printf("%d %lld %lld %lld\n",i,zz.first>>15,zz.first%(1<<15),zz.second);
				int x[6],mask=zz.first%(1<<15);
				zz.first/=(1<<15);
				for(int j=0;j<num;j++)x[j]=zz.first%3,zz.first/=3;
				vector<int>v0,v1;
				for(int j=0;j<num;j++)
				{
					if(x[j]==0)v0.push_back(j);
					else if(x[j]==1)v1.push_back(j);
				}
				if(v0.size()+v1.size()==0)
				{
					dat[mask]+=zz.second*com(len,i);
					dat[mask]%=mod;
					continue;
				}
				for(int p=0;p<(1<<v0.size());p++)
				{
					int y[6];
					int m=mask;
					for(int j=0;j<num;j++)y[j]=x[j];
					for(int j=0;j<v0.size();j++)
					{
						if(p&(1<<j))
						{
							for(int k=0;k<num;k++)if(y[k]==1)m|=1<<dig[v0[j]][k];
							y[v0[j]]=1;
						}
					}
					for(int q=0;q<(1<<v1.size());q++)
					{
						if(p==0&&q==0)continue;
						int z[6];
						for(int j=0;j<num;j++)z[j]=y[j];
						for(int j=0;j<v1.size();j++)if(q&(1<<j))z[v1[j]]=2;
						ll t=0;
						for(int j=num-1;j>=0;j--)t*=3,t+=z[j];
						t=(t<<15)+m;
						dp[i+1][t]=(dp[i+1][t]+zz.second)%mod;
					}
				}
			}
		}
		vector<ll>ans;
	//	for(int i=0;i<4;i++)printf("%lld\n",dat[i]);
		for(int i=0;i<(1<<15);i++)ans.push_back(dat[i]);
		vector<ll>r=ppooww(ans,dim);
		ll aaa=r[0];
	//	for(int i=1;i<=num;i++)aaa=aaa*i%mod;
		return int(aaa);
	}
};