AOJ 2377 ThreeRooks
- 問題概要:
N*Mの盤面にK個障害物があって飛車を3つ置いて互いに取り合わないようにする方法を数えろ。
N.M<=10^9,K<=10^5
- 解法:
適当に包除を考えて自明な場合を抜くと3個の飛車がL字型に並んでるやつを数えればいいことがわかるので頑張ってsegtreeで平面走査しながら数える
- 感想:
segtree自体は簡単だし実装重くてもどないかなるやろとか思って始めたら2時間半くらいかかった。6824B書いたし充分つらかった(デバッグ出力含)。
解法は簡単だから1100ではないと思う。
これはOI系で出すべき(確信)、あと受験生がやるものではなかった
- コード
#include<stdio.h> #include<vector> #include<algorithm> #include<map> using namespace std; #define SIZE 262144 typedef long long ll; ll mod=1000000007; typedef pair<ll,ll>pii; class segtree { public: ll seg[SIZE*2]; ll flag[SIZE*2]; ll lf[SIZE*2],rf[SIZE*2]; void init(vector<ll>vec) { for(int i=0;i<SIZE*2;i++) { seg[i]=flag[i]=lf[i]=rf[i]=0; } for(int i=0;i<vec.size()-1;i++) { lf[SIZE+i]=vec[i]; rf[SIZE+i]=vec[i+1]-1; } for(int i=vec.size();i<SIZE;i++) { lf[SIZE+i]=rf[SIZE+i-1]+1; rf[SIZE+i]=rf[SIZE+i-1]; } for(int i=SIZE-1;i>=1;i--) { lf[i]=lf[i*2]; rf[i]=rf[i*2+1]; } } void add(int beg,int end,int node,int lb,int ub,ll num) { if(ub<beg||end<lb)return; if(beg<=lb&&ub<=end) { seg[node]=(seg[node]+num*(rf[node]-lf[node]+1))%mod; flag[node]+=num; flag[node]%=mod; return; } if(flag[node]!=0) { seg[node*2]=(seg[node*2]+flag[node]*(rf[node*2]-lf[node*2]+1))%mod; seg[node*2+1]=(seg[node*2+1]+flag[node]*(rf[node*2+1]-lf[node*2+1]+1))%mod; flag[node*2]=(flag[node*2]+flag[node])%mod; flag[node*2+1]=(flag[node*2+1]+flag[node])%mod; flag[node]=0; } add(beg,end,node*2,lb,(lb+ub)/2,num); add(beg,end,node*2+1,(lb+ub)/2+1,ub,num); seg[node]=(seg[node*2]+seg[node*2+1])%mod; } ll get(int beg,int end,int node,int lb,int ub) { if(ub<beg||end<lb)return 0; if(beg<=lb&&ub<=end) { return seg[node]; } if(flag[node]!=0) { seg[node*2]=(seg[node*2]+flag[node]*(rf[node*2]-lf[node*2]+1))%mod; seg[node*2+1]=(seg[node*2+1]+flag[node]*(rf[node*2+1]-lf[node*2+1]+1))%mod; flag[node*2]=(flag[node*2]+flag[node])%mod; flag[node*2+1]=(flag[node*2+1]+flag[node])%mod; flag[node]=0; } return (get(beg,end,node*2,lb,(lb+ub)/2)+get(beg,end,node*2+1,(lb+ub)/2+1,ub))%mod; } }; segtree tree; typedef vector<vector<ll> >vvi; ll getnuml(ll mx,ll my,vector<pii>vec) { sort(vec.begin(),vec.end()); vector<ll>zv; zv.push_back(0); for(int i=0;i<vec.size();i++) { zv.push_back(vec[i].second); zv.push_back(vec[i].second+1); } zv.push_back(my); sort(zv.begin(),zv.end()); vector<ll>zat; ll now=-1; for(int i=0;i<zv.size();i++) { if(now!=zv[i]) { now=zv[i]; zat.push_back(now); } } tree.init(zat); vector<ll>zx; for(int i=0;i<vec.size();i++) { zx.push_back(vec[i].first); } vector<ll>z; now=-1; for(int i=0;i<zx.size();i++) { if(now!=zx[i]) { now=zx[i]; z.push_back(now); } } zx=z; zx.push_back(mx); vec.push_back(make_pair(mx,-1)); ll ret=0; ret=(((my*(my-1)/2)%mod)*((zx[0]*(zx[0]-1)/2)%mod)*2)%mod; tree.add(0,zat.size()-2,1,0,SIZE-1,((zx[0])*(my-1))%mod); vvi dat; dat.resize(int(zx.size())); int pt=0; for(int i=0;i<vec.size();i++) { int low=lower_bound(zat.begin(),zat.end(),vec[i].second)-zat.begin(); if(vec[i].first!=zx[pt])pt++; dat[pt].push_back(low); } for(int i=0;i<zx.size()-1;i++) {/* printf("%lld:\n",zx[i]); for(int j=0;j<6;j++) { printf("%lld ",tree.get(j,j,1,0,SIZE-1)); } printf("\n");*/ ll sum=0; ll bef=0; dat[i].push_back(zat.size()-1); //for(int j=0;j<dat[i].size();j++)printf("%lld ",zat[dat[i][j]]);printf("\n"); for(int j=0;j<dat[i].size();j++) { if(zat[dat[i][j]]!=zat[bef]) { ll d=zat[dat[i][j]]-zat[bef]; sum+=tree.get(bef,dat[i][j]-1,1,0,SIZE-1); tree.add(bef,dat[i][j]-1,1,0,SIZE-1,((d-1)%mod)); } bef=dat[i][j]+1; tree.add(dat[i][j],dat[i][j],1,0,SIZE-1,(mod-tree.get(dat[i][j],dat[i][j],1,0,SIZE-1))%mod); //printf("sum:%lld\n",sum); }/* for(int j=0;j<6;j++) { printf("%lld ",tree.get(j,j,1,0,SIZE-1)); } printf("\n");*/ sum+=(tree.get(0,zat.size()-2,1,0,SIZE-1)*(zx[i+1]-zx[i]-1))%mod; sum%=mod; //printf("sss:%lld\n",sum); sum+=(((my*(my-1)/2)%mod)*(((zx[i+1]-zx[i]-1)*(zx[i+1]-zx[i]-2)/2)%mod)*2)%mod; sum%=mod; tree.add(0,zat.size()-2,1,0,SIZE-1,((zx[i+1]-zx[i]-1)*(my-1))%mod); /*for(int j=0;j<6;j++) { printf("%lld ",tree.get(j,j,1,0,SIZE-1)); } printf("\n");*/ ret+=sum; ret%=mod; //printf("%lld\n",sum); } return ret; } ll get3(ll mx,ll my,vector<pii>vec) { sort(vec.begin(),vec.end()); vector<ll>zx; for(int i=0;i<vec.size();i++) { zx.push_back(vec[i].first); } vector<ll>z; ll now=-1; for(int i=0;i<zx.size();i++) { if(now!=zx[i]) { now=zx[i]; z.push_back(now); } } zx=z; vvi dat; dat.resize(int(zx.size())); int pt=0; for(int i=0;i<vec.size();i++) { if(vec[i].first!=zx[pt])pt++; dat[pt].push_back(vec[i].second); } ll ret=0; ret=mx-zx.size(); ret*=my; ret%=mod; ret*=(my-1); ret%=mod; ret*=(my-2); ret%=mod; ret*=(mod+1)/6; ret%=mod; for(int i=0;i<dat.size();i++) { ll sum=0; ll bef=0; dat[i].push_back(my); for(int j=0;j<dat[i].size();j++) { if(dat[i][j]!=bef) { ll n=1; n*=(dat[i][j]-bef); n%=mod; n*=(dat[i][j]-bef-1); n%=mod; n*=(dat[i][j]-bef-2); n%=mod; n*=(mod+1)/6; n%=mod; sum=(sum+n)%mod; } bef=dat[i][j]+1; } ret+=sum; ret%=mod; } return ret; } ll get2(ll mx,ll my,vector<pii>vec) { sort(vec.begin(),vec.end()); vector<ll>zx; for(int i=0;i<vec.size();i++) { zx.push_back(vec[i].first); } vector<ll>z; ll now=-1; for(int i=0;i<zx.size();i++) { if(now!=zx[i]) { now=zx[i]; z.push_back(now); } } zx=z; vvi dat; dat.resize(int(zx.size())); int pt=0; for(int i=0;i<vec.size();i++) { if(vec[i].first!=zx[pt])pt++; dat[pt].push_back(vec[i].second); } ll ret=0; ret=mx-zx.size(); ret*=my; ret%=mod; ret*=(my-1); ret%=mod; ret*=(mod+1)/2; ret%=mod; for(int i=0;i<dat.size();i++) { ll sum=0; ll bef=0; dat[i].push_back(my); for(int j=0;j<dat[i].size();j++) { if(dat[i][j]!=bef) { ll n=1; n*=(dat[i][j]-bef); n%=mod; n*=(dat[i][j]-bef-1); n%=mod; n*=(mod+1)/2; n%=mod; sum=(sum+n)%mod; } bef=dat[i][j]+1; } ret+=sum; ret%=mod; } return (ret*((mx*my-vec.size()-2)%mod))%mod; } int main() { ll mx,my,num; scanf("%lld%lld%lld",&mx,&my,&num); vector<pii>vec; for(int i=0;i<num;i++) { int za,zb; scanf("%d%d",&za,&zb); vec.push_back(make_pair(za,zb)); } if(mx*my-num<3) { printf("0\n"); return 0; } ll ans=1; ans*=(mx*my-num)%mod; ans%=mod; ans*=(mx*my-num-1)%mod; ans%=mod; ans*=(mx*my-num-2)%mod; ans%=mod; ans*=(mod+1)/6; ans%=mod; ll g21,g22,g31,g32,gl1,gl2; g21=get2(mx,my,vec); g31=get3(mx,my,vec); for(int i=0;i<vec.size();i++) { swap(vec[i].first,vec[i].second); } g22=get2(my,mx,vec); g32=get3(my,mx,vec); for(int i=0;i<vec.size();i++) { swap(vec[i].first,vec[i].second); } gl1=getnuml(mx,my,vec); for(int i=0;i<vec.size();i++) { vec[i].first=mx-1-vec[i].first; } gl2=getnuml(mx,my,vec); ans-=g21+g22; ans+=(g31+g32)*2; ans+=(gl1+gl2); ans%=mod; //printf("%lld %lld\n",gl1,gl2); printf("%lld\n",(ans+mod)%mod); }