AOJ 2377 ThreeRooks

  • 問題概要:

N*Mの盤面にK個障害物があって飛車を3つ置いて互いに取り合わないようにする方法を数えろ。
N.M<=10^9,K<=10^5

  • 解法:

適当に包除を考えて自明な場合を抜くと3個の飛車がL字型に並んでるやつを数えればいいことがわかるので頑張ってsegtreeで平面走査しながら数える

  • 感想:

segtree自体は簡単だし実装重くてもどないかなるやろとか思って始めたら2時間半くらいかかった。6824B書いたし充分つらかった(デバッグ出力含)。
解法は簡単だから1100ではないと思う。

これはOI系で出すべき(確信)、あと受験生がやるものではなかった

  • コード
#include<stdio.h>
#include<vector>
#include<algorithm>
#include<map>
using namespace std;
#define SIZE 262144
typedef long long ll;
ll mod=1000000007;
typedef pair<ll,ll>pii;
class segtree
{
public:
	ll seg[SIZE*2];
	ll flag[SIZE*2];
	ll lf[SIZE*2],rf[SIZE*2];
	void init(vector<ll>vec)
	{
		for(int i=0;i<SIZE*2;i++)
		{
			seg[i]=flag[i]=lf[i]=rf[i]=0;
		}
		for(int i=0;i<vec.size()-1;i++)
		{
			lf[SIZE+i]=vec[i];
			rf[SIZE+i]=vec[i+1]-1;
		}
		for(int i=vec.size();i<SIZE;i++)
		{
			lf[SIZE+i]=rf[SIZE+i-1]+1;
			rf[SIZE+i]=rf[SIZE+i-1];
		}
		for(int i=SIZE-1;i>=1;i--)
		{
			lf[i]=lf[i*2];
			rf[i]=rf[i*2+1];
		}
	}
	void add(int beg,int end,int node,int lb,int ub,ll num)
	{
		if(ub<beg||end<lb)return;
		if(beg<=lb&&ub<=end)
		{
			seg[node]=(seg[node]+num*(rf[node]-lf[node]+1))%mod;
			flag[node]+=num;
			flag[node]%=mod;
			return;
		}
		if(flag[node]!=0)
		{
			seg[node*2]=(seg[node*2]+flag[node]*(rf[node*2]-lf[node*2]+1))%mod;
			seg[node*2+1]=(seg[node*2+1]+flag[node]*(rf[node*2+1]-lf[node*2+1]+1))%mod;
			flag[node*2]=(flag[node*2]+flag[node])%mod;
			flag[node*2+1]=(flag[node*2+1]+flag[node])%mod;
			flag[node]=0;
		}
		add(beg,end,node*2,lb,(lb+ub)/2,num);
		add(beg,end,node*2+1,(lb+ub)/2+1,ub,num);
		seg[node]=(seg[node*2]+seg[node*2+1])%mod;
	}
	ll get(int beg,int end,int node,int lb,int ub)
	{
		if(ub<beg||end<lb)return 0;
		if(beg<=lb&&ub<=end)
		{
			return seg[node];
		}
		if(flag[node]!=0)
		{
			seg[node*2]=(seg[node*2]+flag[node]*(rf[node*2]-lf[node*2]+1))%mod;
			seg[node*2+1]=(seg[node*2+1]+flag[node]*(rf[node*2+1]-lf[node*2+1]+1))%mod;
			flag[node*2]=(flag[node*2]+flag[node])%mod;
			flag[node*2+1]=(flag[node*2+1]+flag[node])%mod;
			flag[node]=0;
		}
		return (get(beg,end,node*2,lb,(lb+ub)/2)+get(beg,end,node*2+1,(lb+ub)/2+1,ub))%mod;
	}
};
segtree tree;
typedef vector<vector<ll> >vvi;
ll getnuml(ll mx,ll my,vector<pii>vec)
{
	sort(vec.begin(),vec.end());
	vector<ll>zv;
	zv.push_back(0);
	for(int i=0;i<vec.size();i++)
	{
		zv.push_back(vec[i].second);
		zv.push_back(vec[i].second+1);
	}
	zv.push_back(my);
	sort(zv.begin(),zv.end());
	vector<ll>zat;
	ll now=-1;
	for(int i=0;i<zv.size();i++)
	{
		if(now!=zv[i])
		{
			now=zv[i];
			zat.push_back(now);
		}
	}
	tree.init(zat);
	vector<ll>zx;
	for(int i=0;i<vec.size();i++)
	{
		zx.push_back(vec[i].first);
	}
	vector<ll>z;
	now=-1;
	for(int i=0;i<zx.size();i++)
	{
		if(now!=zx[i])
		{
			now=zx[i];
			z.push_back(now);
		}
	}
	zx=z;
	zx.push_back(mx);
	vec.push_back(make_pair(mx,-1));
	ll ret=0;
	ret=(((my*(my-1)/2)%mod)*((zx[0]*(zx[0]-1)/2)%mod)*2)%mod;
	tree.add(0,zat.size()-2,1,0,SIZE-1,((zx[0])*(my-1))%mod);
	vvi dat;
	dat.resize(int(zx.size()));
	int pt=0;
	for(int i=0;i<vec.size();i++)
	{
		int low=lower_bound(zat.begin(),zat.end(),vec[i].second)-zat.begin();
		if(vec[i].first!=zx[pt])pt++;
		dat[pt].push_back(low);
	}
	for(int i=0;i<zx.size()-1;i++)
	{/*
		printf("%lld:\n",zx[i]);
		for(int j=0;j<6;j++)
		{
			printf("%lld ",tree.get(j,j,1,0,SIZE-1));
		}
		printf("\n");*/
		ll sum=0;
		ll bef=0;
		dat[i].push_back(zat.size()-1);
		//for(int j=0;j<dat[i].size();j++)printf("%lld ",zat[dat[i][j]]);printf("\n");
		for(int j=0;j<dat[i].size();j++)
		{
			if(zat[dat[i][j]]!=zat[bef])
			{
				ll d=zat[dat[i][j]]-zat[bef];
				sum+=tree.get(bef,dat[i][j]-1,1,0,SIZE-1);
				tree.add(bef,dat[i][j]-1,1,0,SIZE-1,((d-1)%mod));
			}
			bef=dat[i][j]+1;
			tree.add(dat[i][j],dat[i][j],1,0,SIZE-1,(mod-tree.get(dat[i][j],dat[i][j],1,0,SIZE-1))%mod);
			//printf("sum:%lld\n",sum);
		}/*
		for(int j=0;j<6;j++)
		{
			printf("%lld ",tree.get(j,j,1,0,SIZE-1));
		}
		printf("\n");*/
		sum+=(tree.get(0,zat.size()-2,1,0,SIZE-1)*(zx[i+1]-zx[i]-1))%mod;
		sum%=mod;
		//printf("sss:%lld\n",sum);
		sum+=(((my*(my-1)/2)%mod)*(((zx[i+1]-zx[i]-1)*(zx[i+1]-zx[i]-2)/2)%mod)*2)%mod;
		sum%=mod;
		tree.add(0,zat.size()-2,1,0,SIZE-1,((zx[i+1]-zx[i]-1)*(my-1))%mod);
		/*for(int j=0;j<6;j++)
		{
			printf("%lld ",tree.get(j,j,1,0,SIZE-1));
		}
		printf("\n");*/
		ret+=sum;
		ret%=mod;
		//printf("%lld\n",sum);
	}
	return ret;
}
ll get3(ll mx,ll my,vector<pii>vec)
{
	sort(vec.begin(),vec.end());
	vector<ll>zx;
	for(int i=0;i<vec.size();i++)
	{
		zx.push_back(vec[i].first);
	}
	vector<ll>z;
	ll now=-1;
	for(int i=0;i<zx.size();i++)
	{
		if(now!=zx[i])
		{
			now=zx[i];
			z.push_back(now);
		}
	}
	zx=z;
	vvi dat;
	dat.resize(int(zx.size()));
	int pt=0;
	for(int i=0;i<vec.size();i++)
	{
		if(vec[i].first!=zx[pt])pt++;
		dat[pt].push_back(vec[i].second);
	}
	ll ret=0;
	ret=mx-zx.size();
	ret*=my;
	ret%=mod;
	ret*=(my-1);
	ret%=mod;
	ret*=(my-2);
	ret%=mod;
	ret*=(mod+1)/6;
	ret%=mod;
	for(int i=0;i<dat.size();i++)
	{
		ll sum=0;
		ll bef=0;
		dat[i].push_back(my);
		for(int j=0;j<dat[i].size();j++)
		{
			if(dat[i][j]!=bef)
			{
				ll n=1;
				n*=(dat[i][j]-bef);
				n%=mod;
				n*=(dat[i][j]-bef-1);
				n%=mod;
				n*=(dat[i][j]-bef-2);
				n%=mod;
				n*=(mod+1)/6;
				n%=mod;
				sum=(sum+n)%mod;
			}
			bef=dat[i][j]+1;
		}
		ret+=sum;
		ret%=mod;
	}
	return ret;
}
ll get2(ll mx,ll my,vector<pii>vec)
{
	sort(vec.begin(),vec.end());
	vector<ll>zx;
	for(int i=0;i<vec.size();i++)
	{
		zx.push_back(vec[i].first);
	}
	vector<ll>z;
	ll now=-1;
	for(int i=0;i<zx.size();i++)
	{
		if(now!=zx[i])
		{
			now=zx[i];
			z.push_back(now);
		}
	}
	zx=z;
	vvi dat;
	dat.resize(int(zx.size()));
	int pt=0;
	for(int i=0;i<vec.size();i++)
	{
		if(vec[i].first!=zx[pt])pt++;
		dat[pt].push_back(vec[i].second);
	}
	ll ret=0;
	ret=mx-zx.size();
	ret*=my;
	ret%=mod;
	ret*=(my-1);
	ret%=mod;
	ret*=(mod+1)/2;
	ret%=mod;
	for(int i=0;i<dat.size();i++)
	{
		ll sum=0;
		ll bef=0;
		dat[i].push_back(my);
		for(int j=0;j<dat[i].size();j++)
		{
			if(dat[i][j]!=bef)
			{
				ll n=1;
				n*=(dat[i][j]-bef);
				n%=mod;
				n*=(dat[i][j]-bef-1);
				n%=mod;
				n*=(mod+1)/2;
				n%=mod;
				sum=(sum+n)%mod;
			}
			bef=dat[i][j]+1;
		}
		ret+=sum;
		ret%=mod;
	}
	return (ret*((mx*my-vec.size()-2)%mod))%mod;
}
int main()
{
	ll mx,my,num;
	scanf("%lld%lld%lld",&mx,&my,&num);
	vector<pii>vec;
	for(int i=0;i<num;i++)
	{
		int za,zb;
		scanf("%d%d",&za,&zb);
		vec.push_back(make_pair(za,zb));
	}
	if(mx*my-num<3)
	{
		printf("0\n");
		return 0;
	}
	ll ans=1;
	ans*=(mx*my-num)%mod;
	ans%=mod;
	ans*=(mx*my-num-1)%mod;
	ans%=mod;
	ans*=(mx*my-num-2)%mod;
	ans%=mod;
	ans*=(mod+1)/6;
	ans%=mod;
	ll g21,g22,g31,g32,gl1,gl2;
	g21=get2(mx,my,vec);
	g31=get3(mx,my,vec);
	for(int i=0;i<vec.size();i++)
	{
		swap(vec[i].first,vec[i].second);
	}
	g22=get2(my,mx,vec);
	g32=get3(my,mx,vec);
	for(int i=0;i<vec.size();i++)
	{
		swap(vec[i].first,vec[i].second);
	}
	gl1=getnuml(mx,my,vec);
	for(int i=0;i<vec.size();i++)
	{
		vec[i].first=mx-1-vec[i].first;
	}
	gl2=getnuml(mx,my,vec);
	ans-=g21+g22;
	ans+=(g31+g32)*2;
	ans+=(gl1+gl2);
	ans%=mod;
	//printf("%lld %lld\n",gl1,gl2);
	printf("%lld\n",(ans+mod)%mod);
}