初心者の初心者による初心者のための典型segment tree

§0.はじめに

この記事は、Competitive Programming Advent Calendar Div2013の11日目の記事として書かれたものです。この記事では、列に対して変な操作をして変なクエリがくる系の問題のsegment treeによる解法を説明します。

また、「初心者の初心者による初心者のための」と銘打っていますが、

  • segment treeとはどういうものか知っている
  • 遅延更新のやり方が分かっている

の2点を仮定します。前者が分からない人は蟻本、後者が分からない人はkyuridenamidaさんのこの記事(http://d.hatena.ne.jp/kyuridenamida/20121114/1352835261)を見てください。

§1.segment treeとは何か

まずsegment treeが何ができるデータ構造なのかを示します。segment treeとはずばり

写像や値の合成を扱うデータ構造」

です。これだけだとピンとこないと思うので、一番有名と思われるRMQを用いて説明します。

RMQの各葉ノードは、xに対してmin(a,x)を出力する写像を持っていると定義できます。この写像は結合的なので、その親ノードにもmin(a,x)の形をした写像を定義できます。こうやってどんどん親ノードに写像を伝播していくと、segment treeが完成するわけです。

世の中には色々なsegment treeがありますが、これらを考えるときに「どういう写像や値を合成したいのか」を考えるのが重要になります。

§2.segment treeの考え方(遅延更新なし)

まず、遅延更新の発生しない(すなわち、更新クエリが一点に対してのみ来る)場合を考えます。

ここで考えることは、「最終的にどういう写像が欲しくて、そのためにはどのようなものを持っておく必要があるか」ということです。

では、次の問題を解いてみることにしましょう。

http://arc008.contest.atcoder.jp/tasks/arc008_4

問題概要:n個の機械が一直線上に並んでいる。この機械は、ある実数xを受け取ると(ax+b)を返す。
ある機械のaやbの変更クエリが与えられるので、n個の機械を順に通したとき値1は最終的に何になるかをクエリごとに答えなさい。


見るからにsegment treeの問題であることはわかると思います。ここで、どのような写像を作りたいかを考えます。

各機械は、(ax+b)を返すようにできているため、(ax+b)型の写像2個を合成してまた(ax+b)型の写像にできればうれしそうです。ここで、合成してみましょう。

合成する2つの写像をそれぞれf(x)=(ax+b),g(x)=(cx+d)と置くと、

g(f(x))=c(f(x))+d=c(ax+b)+d=acx+(bc+d)

となり、これは(ax+b)の形をしています。よって、この写像を上手く合成することができました。

この問題ではn<=10^12ですが、これはクエリを先読みして座標圧縮をすればよいだけなので本質ではありません。

以下コード

#include<stdio.h>
#include<vector>
#include<algorithm>
using namespace std;
typedef pair<double,double>pdd;
typedef long long ll;
typedef pair<ll,pdd>pidd;
typedef pair<ll,ll>pii;
pdd seg[262144];
void update(int node,pdd a)
{
	node+=131072;
	seg[node]=a;
	for(;;)
	{
		node/=2;
		if(node==0)
		{
			break;
		}
		seg[node]=make_pair(seg[node*2].first*seg[node*2+1].first,seg[node*2].second*seg[node*2+1].first+seg[node*2+1].second);
	}
}
int main()
{
	int num,query;
	scanf("%d%d",&num,&query);
	vector<pidd>vec;
	vector<ll>prs;
	for(int i=0;i<131072;i++)
	{
		update(i,make_pair(1.0,0.0));
	}
	for(int i=0;i<query;i++)
	{
		ll za;
		double zb,zc;
		scanf("%lld%lf%lf",&za,&zb,&zc);
		vec.push_back(make_pair(za,make_pair(zb,zc)));
		prs.push_back(za);
	}
	sort(prs.begin(),prs.end());
	vector<ll>uni;
	ll now=-1;
	for(int i=0;i<query;i++)
	{
		if(now!=prs[i])
		{
			uni.push_back(prs[i]);
			now=prs[i];
		}
	}
	prs=uni;
	double maxi,mini;
	maxi=mini=seg[1].first+seg[1].second;
	for(int i=0;i<query;i++)
	{
		int low=lower_bound(prs.begin(),prs.end(),vec[i].first)-prs.begin();
		update(low,vec[i].second);
		maxi=max(maxi,seg[1].first+seg[1].second);
		mini=min(mini,seg[1].first+seg[1].second);
	}
	printf("%lf\n%lf\n",mini,maxi);
}

もう1問見てみましょう。Codeforces Div1のE問題ですが、値の合成を考えるだけで簡単に解けます。(実装は少し重いですが)

http://codeforces.com/problemset/problem/295/E

問題概要:直線状に点がたくさんあります。次のクエリに高速に答えてください。

  • 指定された点をxからx+dの位置に移動する
  • 指定された区間[l,r]に対し、[l,r]間にある任意の2点の組についてのその距離の総和を求める


移動クエリは、点を消してまた作ればよさそうです。問題は2つ目のクエリです。これを考えます。

求めたいものは、「そのノードに対応する区間[l,r]に対し、[l,r]間にある任意の2点の組についてのその距離の総和」です。各ノードがこの値を持っているとき、上のノードにうまく値を伝播することを考えます。

しかしこの場合、各ノードがこの求めたい値だけを持っている場合合成が上手くいきません。ということは、この値に加えてほかの値も持たなければいけないということがわかります。逆に、この持つべき値が見抜けたならこの問題は解けたも同然です。

l<=m

#include<stdio.h>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef pair<ll,ll>pii;
typedef pair<ll,pii>pi3;
typedef pair<pii,pii>pi4;
#define SIZE (1LL<<31LL)
#define SEG 1048576
class segtree
{
public:
	ll segans[SEG*2];
	ll segl[SEG*2];
	ll segr[SEG*2];
	ll flag[SEG*2];
	ll xl[SEG*2];
	ll xr[SEG*2];
	void init(vector<ll>vec)
	{
		for(int i=0;i<vec.size()-1;i++)
		{
			xl[SEG+i]=vec[i];
			xr[SEG+i]=vec[i+1]-1;
		}
		for(int i=vec.size()-1;i<SEG;i++)
		{
			xl[SEG+i]=SIZE;
			xr[SEG+i]=SIZE-1;
		}
		for(int i=SEG-1;i>=1;i--)
		{
			xl[i]=xl[i*2];
			xr[i]=xr[i*2+1];
		}
	}
	void update(ll pl,int node,int num)
	{
		if(xr[node]<pl||pl<xl[node])
		{
			return;
		}
		if(node>=SEG)
		{
			if(num==1)
			{
				segl[node]=0;
				segr[node]=xr[node]+1-xl[node];
				flag[node]=1;
			}
			else
			{
				segl[node]=0;
				segr[node]=0;
				flag[node]=0;
			}
			return;
		}
		update(pl,node*2,num);
		update(pl,node*2+1,num);
		flag[node]=flag[node*2]+flag[node*2+1];
		segans[node]=segans[node*2]+segans[node*2+1]+segr[node*2]*flag[node*2+1]+segl[node*2+1]*flag[node*2];
		segl[node]=segl[node*2]+(xr[node*2]-xl[node*2]+1)*flag[node*2+1]+segl[node*2+1];
		segr[node]=segr[node*2]+(xr[node*2+1]-xl[node*2+1]+1)*flag[node*2]+segr[node*2+1];
	}
	pi4 calc(ll beg,ll end,int node)
	{
		if(xr[node]<beg||end<xl[node])
		{
			return make_pair(make_pair(0,0),make_pair(0,0));
		}
		if(beg<=xl[node]&&xr[node]<=end)
		{
			return make_pair(make_pair(segans[node],flag[node]),make_pair(segl[node],segr[node]));
		}
		pi4 zl=calc(beg,end,node*2);
		pi4 zr=calc(beg,end,node*2+1);
		ll ra=zl.first.first+zr.first.first+zr.second.first*zl.first.second+zl.second.second*zr.first.second;
		ll rb=zl.first.second+zr.first.second;
		ll rl=zl.second.first+zr.second.first+(xr[node*2]-xl[node*2]+1)*zr.first.second;
		ll rr=zl.second.second+zr.second.second+(xr[node*2+1]-xl[node*2+1]+1)*zl.first.second;
		return make_pair(make_pair(ra,rb),make_pair(rl,rr));
	}
};
segtree tree;
ll pla[100000];
ll zp[100000];
vector<ll>zat(vector<ll>vec)
{
	sort(vec.begin(),vec.end());
	vector<ll>ret;
	ll now=-SIZE*2;
	for(int i=0;i<vec.size();i++)
	{
		if(now!=vec[i])
		{
			now=vec[i];
			ret.push_back(now);
		}
	}
	return ret;
}
int main()
{
	int num;
	scanf("%d",&num);
	vector<ll>zv;
	vector<ll>ini;
	zv.push_back(-SIZE);
	zv.push_back(SIZE-1);
	for(int i=0;i<num;i++)
	{
		int zan;
		scanf("%d",&zan);
		pla[i]=zan;
		zp[i]=zan;
		zv.push_back(zan);
		zv.push_back(zan+1);
		ini.push_back(zan);
	}
	int query;
	scanf("%d",&query);
	vector<pi3>que;
	for(int p=0;p<query;p++)
	{
		int za;
		ll zb,zc;
		scanf("%d%I64d%I64d",&za,&zb,&zc);
		que.push_back(make_pair(za,make_pair(zb,zc)));
		if(za==1)
		{
			zb--;
			zp[zb]+=zc;
			zv.push_back(zp[zb]);
			zv.push_back(zp[zb]+1);
		}
		else
		{
			zv.push_back(zb);
			zv.push_back(zb+1);
			zv.push_back(zc);
			zv.push_back(zc+1);
		}
	}
	vector<ll>vec=zat(zv);
	tree.init(vec);
	for(int i=0;i<num;i++)
	{
		tree.update(ini[i],1,1);
	}
	for(int p=0;p<query;p++)
	{
		int za;
		ll zb,zc;
		za=que[p].first;
		zb=que[p].second.first;
		zc=que[p].second.second;
		if(za==1)
		{
			zb--;
			tree.update(pla[zb],1,-1);
			pla[zb]+=zc;
			tree.update(pla[zb],1,1);
		}
		else
		{
			pi4 zan=tree.calc(zb,zc,1);
			printf("%I64d\n",zan.first.first);
		}
	}
}

§3.segment treeの考え方(遅延更新あり)

今度は遅延更新をするバージョンのsegment treeについて扱います。すなわち、区間に対しての操作のクエリが与えられます。といっても考えることはそれほど変わりません。

遅延更新をする際に新しく考えるのは、「区間に対するクエリに対しての更新をどのように行うか」のみです。しかしこれが本質になる問題はあまり見ません(難しくても累積和で前計算した値を足すとか)。つまり本質は先ほど提示した「作りたい値をどのように合成するか」だけです。

ただし、区間に対しての操作をする場合、作りたい値をつくるために、「ノードを更新するときにその更新が定数時間でできるような値」を持っておく必要があります。有名な例として、

「ある区間にある値を加算する、この時どの点の値も負になってはならない」
「ある区間の0の個数を数える」

という2種類のクエリに答えるsegment treeを考えます。

このとき、更新クエリが一点に対してのものなら簡単で、単に各ノードにその区間にある0の個数を持っておけばいいだけです。しかし区間に対しての更新の場合、その方法ではうまくいきません(少し考えればわかるはずです)。

ここで、0は存在するなら常に区間の最小値であることに注目すると、「その区間の最小値」「その区間の最小値の個数」を持っておけばこれが計算できます。
更新クエリに対しては単にその区間の最小値をいじるだけ、またこの値の合成と伝播は左右のノードを見て最小値が小さい方の最小値の個数(同じ場合は合計)を、その区間の最小値の個数に設定すればいいだけです。

§4.動的構築

「動的」と聞いてこわいと思った方もいるかもしれませんが、これは単純です。たとえば10^9個くらいのノードを持つsegment treeが欲しいときに使えます。

このsegment treeは、アクセスされるノードがほんの一部であることを利用し、「ノードが必要になったらつくる」ということをします。実装も単純で、

「今見ているノードに左右の子がなければつくる」

ということだけを、普通のsegment treeに追加すればいいだけです。ただ、一点に対する更新も実装は根から再帰降下する感じになります。

なんのためにこのようなsegment treeがあるかというと、例えばクエリがオンラインで与えられるとき、先読みができないので座標圧縮ができません。また、座標圧縮の添え字や区間の長さを計算するが面倒なときにも使えます。普通のsegment treeに子がなければつくるという操作を付け加えただけなので、実装は単純です。

ただし、再帰の深さが深くなって少し重くなる(10^9個の要素なら30段程度のsegment treeになります)ことと、各ノードでの操作が少し多くなるのでさらに重くなること、そして左右の子へのポインタを持つ上にノードの数も増えるためMLEには十分注意が必要です。

先ほどの問題を動的構築にしたときのコードです。座標圧縮をさぼっている分実装が単純になっています。なおMLEします

#include<stdio.h>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef pair<ll,ll>pii;
typedef pair<pii,pii>pi4;
#define SIZE (1LL<<31LL)
class segtree
{
public:
	ll segans[20000000];
	ll segl[20000000];
	ll segr[20000000];
	int lko[20000000];
	int rko[20000000];
	ll flag[20000000];
	int nowpt;
	void init()
	{
		fill(lko,lko+20000000,-1);
		fill(rko,rko+20000000,-1);
		nowpt=1;
	}
	void update(ll pl,int node,ll lb,ll ub,int num)
	{
		if(lb==ub)
		{
			if(num==1)
			{
				segl[node]=0;
				segr[node]=1;
				flag[node]=1;
			}
			else
			{
				segl[node]=0;
				segr[node]=0;
				flag[node]=0;
			}
			return;
		}
		if(lko[node]==-1)
		{
			lko[node]=nowpt++;
		}
		if(rko[node]==-1)
		{
			rko[node]=nowpt++;
		}
		if(pl<=(lb+ub-1)/2)
		{
			update(pl,lko[node],lb,(lb+ub-1)/2,num);
		}
		else
		{
			update(pl,rko[node],(lb+ub-1)/2+1,ub,num);
		}
		flag[node]=flag[lko[node]]+flag[rko[node]];
		segans[node]=segans[lko[node]]+segans[rko[node]]+segr[lko[node]]*flag[rko[node]]+segl[rko[node]]*flag[lko[node]];
		segl[node]=segl[lko[node]]+((ub-lb+1)/2)*flag[rko[node]]+segl[rko[node]];
		segr[node]=segr[lko[node]]+((ub-lb+1)/2)*flag[lko[node]]+segr[rko[node]];
	}
	pi4 calc(ll beg,ll end,int node,ll lb,ll ub)
	{
		if(ub<beg||end<lb)
		{
			return make_pair(make_pair(0,0),make_pair(0,0));
		}
		if(beg<=lb&&ub<=end)
		{
			return make_pair(make_pair(segans[node],flag[node]),make_pair(segl[node],segr[node]));
		}
		if(lko[node]==-1)
		{
			lko[node]=nowpt++;
		}
		if(rko[node]==-1)
		{
			rko[node]=nowpt++;
		}
		pi4 zl=calc(beg,end,lko[node],lb,(lb+ub-1)/2);
		pi4 zr=calc(beg,end,rko[node],(lb+ub-1)/2+1,ub);
		ll ra=zl.first.first+zr.first.first+zr.second.first*zl.first.second+zl.second.second*zr.first.second;
		ll rb=zl.first.second+zr.first.second;
		ll rl=zl.second.first+zr.second.first+((ub-lb+1)/2)*zr.first.second;
		ll rr=zl.second.second+zr.second.second+((ub-lb+1)/2)*zl.first.second;
		return make_pair(make_pair(ra,rb),make_pair(rl,rr));
	}
};
segtree tree;
ll pla[100000];
int main()
{
	int num;
	scanf("%d",&num);
	tree.init();
	for(int i=0;i<num;i++)
	{
		int zan;
		scanf("%d",&zan);
		pla[i]=zan;
		tree.update(zan,0,-SIZE,SIZE-1,1);
	}
	int query;
	scanf("%d",&query);
	for(int p=0;p<query;p++)
	{
		int za;
		ll zb,zc;
		scanf("%d%I64d%I64d",&za,&zb,&zc);
		if(za==1)
		{
			zb--;
			tree.update(pla[zb],0,-SIZE,SIZE-1,-1);
			pla[zb]+=zc;
			tree.update(pla[zb],0,-SIZE,SIZE-1,1);
		}
		else
		{
			pi4 zan=tree.calc(zb,zc,0,-SIZE,SIZE-1);
			printf("%I64d\n",zan.first.first);
		}
	}
}

§5.問題集

一応僕が解いた「典型segment tree」の問題を載せておきます。CodeforcesのDiv1 Eとかしかないので解きごたえはあるでしょうが、上に書いた内容を踏まえて考えれば半ば機械的に解けるはずです。是非コンテストで典型segment treeを倒しまくってください。

以下問題(反転でヒント)
http://codeforces.com/problemset/problem/266/E 二項定理
http://codeforces.com/problemset/problem/256/E DP列を合成します
http://codeforces.com/problemset/problem/316/E3 フィボナッチ、実験しましょう フィボナッチならではの性質を使います
http://codeforces.com/problemset/problem/240/F 持つべきsegment treeは26個です
http://codeforces.com/problemset/problem/242/E やはり持つべきsegment treeは1個とは限りません xorと加算は相性が悪いので考えましょう
http://codeforces.com/problemset/problem/121/E やはりsegment treeをたくさん持ちます どうやら想定解ではないそうで定数倍がめっちゃ厳しいです
http://codeforces.com/contest/280/problem/D (実は解けなかった) ちょっと大きいオーダーなら簡単なのですが、そこからオーダーを落とすのが難しいです
他によさげな問題があったら教えてください。ここでは、列に対してへんな操作を行う系のいわゆる見た瞬間にsegment treeだとわかる系の問題しかのせていないのでそれ以外の問題が解きたい人は(http://hogloid.hatenablog.com/entry/20121227/1356608982)ここら辺を参考にしてください。