Codeforces 587F Duff is Mad

  • 問題概要

文字列がたくさんある。次のクエリを処理せよ。

クエリ: k個目文字列の中に出てくるr番目からl番目までの文字列のoccurrenceの合計を求めよ。

  • 解法

安心安全の筋肉系文字列 & 平方分割。

長い文字列に関してはSAで答えを前計算(Aho-Corasickをつかったほうが良かった気がするけどライブラリがないのでSAで殴った)して累積和の形で持っておき、短い文字列に関してはクエリを先読みしてtrie木のEuler-tour列に平方分割系のデータ構造をのせて殴る(更新がO(N)回、質問がO(Nsqrt(N))回くるときはsegtreeやBITより平方分割のほうが速い)。

CFを埋めているせいで最近こういう筋肉系ばかりやっている。このブログが実装自慢ブログと化している。

こういう実装をバグらせなくなってきたのはとてもいいこと(今回もSAのライブラリ以外に200行以上本質部分のコードを書いたのにあまりバグが生えずびっくりした)だが、頭の整理と気力が追い付かないので一息に書ききれない。

SAのせいで発生する(Aho-Corasickだとたぶん大丈夫)MLEを避けるためにバケットサイズの比がいびつになった。これなら余計にlogつけても通ったかもしれない。

#include<stdio.h>
#include<vector>
#include<algorithm>
#include<string>
#include<iostream>
using namespace std;
vector<int>sais(vector<int>vec)//1-originにして末尾に0を加えてから呼ぶのを忘れないこと
{
	if (vec.empty())
	{
		vector<int>v;
		return v;
	}
	vector<int>dat;
	dat.resize(vec.size());
	dat[vec.size() - 1] = 0;
	for (int i = vec.size() - 2; i >= 0; i--)
	{
		if (vec[i]>vec[i + 1])dat[i] = 1;
		else if (vec[i]<vec[i + 1])dat[i] = 0;
		else dat[i] = dat[i + 1];
	}
	vector<vector<int> >sa;
	int maxi = 0;
	for (int i = 0; i<vec.size(); i++)maxi = max(maxi, vec[i]);
	sa.resize(maxi + 1);
	vector<int>vnum;
	vnum.resize(maxi + 1);
	for (int i = 0; i<vec.size(); i++)vnum[vec[i]]++;
	vector<bool>islms;
	islms.resize(vec.size());
	fill(islms.begin(), islms.end(), false);
	for (int i = 0; i <= maxi; i++)
	{
		sa[i].resize(vnum[i]);
		fill(sa[i].begin(), sa[i].end(), -1);
	}
	vector<int>pt1, pt2;
	pt1.resize(maxi + 1);
	pt2.resize(maxi + 1);
	for (int i = 0; i <= maxi; i++)
	{
		pt2[i] = vnum[i] - 1;
	}
	for (int i = vec.size() - 1; i >= 1; i--)
	{
		if ((dat[i - 1] == 1 && dat[i] == 0) || i == vec.size() - 1)
		{
			sa[vec[i]][pt2[vec[i]]--] = i;
			islms[i] = true;
		}
	}
	for (int i = 0; i <= maxi; i++)
	{
		for (int j = 0; j<vnum[i]; j++)
		{
			if (sa[i][j]>0)
			{
				if (dat[sa[i][j] - 1] == 1)
				{
					sa[vec[sa[i][j] - 1]][pt1[vec[sa[i][j] - 1]]++] = sa[i][j] - 1;
				}
			}
		}
	}
	for (int i = 1; i <= maxi; i++)
	{
		for (int j = pt2[i] + 1; j<vnum[i]; j++)sa[i][j] = -1;
		pt2[i] = vnum[i] - 1;
	}
	for (int i = maxi; i >= 0; i--)
	{
		for (int j = vnum[i] - 1; j >= 0; j--)
		{
			if (sa[i][j]>0)
			{
				if (dat[sa[i][j] - 1] == 0)
				{
					sa[vec[sa[i][j] - 1]][pt2[vec[sa[i][j] - 1]]--] = sa[i][j] - 1;
				}
			}
		}
	}
	vector<int>d;
	d.resize(vec.size());
	int cnt = 0;
	vector<int>bef;
	bef.push_back(0);
	bool fl = false;
	for (int i = 0; i <= maxi; i++)
	{
		for (int j = 0; j<vnum[i]; j++)
		{
			if (islms[sa[i][j]])
			{
				vector<int>zk;
				zk.push_back(vec[sa[i][j]]);
				for (int k = sa[i][j] + 1; k<vec.size(); k++)
				{
					zk.push_back(vec[k]);
					if (islms[k])break;
				}
				if (zk != bef)cnt++;
				else if (vec[sa[i][j]] != 0)fl = true;
				d[sa[i][j]] = cnt;
				bef = zk;
			}
		}
	}
	vector<int>vt;
	for (int i = 0; i<vec.size(); i++)if (islms[i])vt.push_back(d[i]);
	vector<int>gv;
	vector<int>nv;
	if (fl)
	{
		gv = sais(vt);
		vector<int>v;
		for (int i = 0; i<vec.size(); i++)
		{
			if (islms[i])v.push_back(i);
		}
		for (int i = 0; i<gv.size(); i++)
		{
			nv.push_back(v[gv[i]]);
		}
	}
	else
	{
		gv = vt;
		nv.resize(gv.size());
		int pt = 0;
		for (int i = 0; i<gv.size(); i++)
		{
			for (;;)
			{
				if (islms[pt])
				{
					nv[gv[i]] = pt;
					pt++;
					break;
				}
				pt++;
			}
		}
	}
	for (int i = 0; i <= maxi; i++)
	{
		fill(sa[i].begin(), sa[i].end(), -1);
		pt1[i] = 0;
		pt2[i] = vnum[i] - 1;
	}
	for (int i = nv.size() - 1; i >= 0; i--)
	{
		sa[vec[nv[i]]][pt2[vec[nv[i]]]--] = nv[i];
	}
	for (int i = 0; i <= maxi; i++)
	{
		for (int j = 0; j<vnum[i]; j++)
		{
			if (sa[i][j]>0)
			{
				if (dat[sa[i][j] - 1] == 1)
				{
					sa[vec[sa[i][j] - 1]][pt1[vec[sa[i][j] - 1]]++] = sa[i][j] - 1;
				}
			}
		}
	}
	for (int i = 1; i <= maxi; i++)
	{
		for (int j = pt2[i] + 1; j<vnum[i]; j++)sa[i][j] = -1;
		pt2[i] = vnum[i] - 1;
	}
	for (int i = maxi; i >= 0; i--)
	{
		for (int j = vnum[i] - 1; j >= 0; j--)
		{
			if (sa[i][j]>0)
			{
				if (dat[sa[i][j] - 1] == 0)
				{
					sa[vec[sa[i][j] - 1]][pt2[vec[sa[i][j] - 1]]--] = sa[i][j] - 1;
				}
			}
		}
	}
	vector<int>ret;
	for (int i = 0; i <= maxi; i++)
	{
		for (int j = 0; j<vnum[i]; j++)
		{
			ret.push_back(sa[i][j]);
		}
	}
	return ret;
}
vector<int>calclcp(vector<int>str, vector<int>sa)//lcp: SAのi-1番目とi番目のlcp
{
	vector<int>rsa;
	rsa.resize(sa.size());
	for (int i = 0; i < sa.size(); i++)rsa[sa[i]] = i;
	vector<int>lcp;
	lcp.resize(sa.size());
	int now = 1;
	for (int i = 0; i < str.size() - 1; i++)
	{
		if (now != 0)now--;
		for (;;)
		{
			if (str[i + now] == str[sa[rsa[i] - 1] + now])now++;
			else
			{
				lcp[rsa[i]] = now;
				break;
			}
		}
	}
	return lcp;
}
#define SIZE 262144
class segtree
{
public:
	int seg[SIZE * 2];
	void init()
	{
		for (int i = 1; i < SIZE * 2; i++)seg[i] = 1000000000;
	}
	void update(int a, int b)
	{
		a += SIZE;
		seg[a] = min(seg[a], b);
		for (;;)
		{
			a /= 2;
			if (a == 0)break;
			seg[a] = min(seg[a * 2], seg[a * 2 + 1]);
		}
	}
	int get(int beg, int end, int node, int lb, int ub)
	{
		if (ub < beg || end < lb)return 1000000000;
		if (beg <= lb&&ub <= end)return seg[node];
		return min(get(beg, end, node * 2, lb, (lb + ub) / 2), get(beg, end, node * 2 + 1, (lb + ub) / 2 + 1, ub));
	}
};
segtree tree;
vector<int>str, sa, lcp, rsa;
void init()
{
	sa = sais(str);
	lcp = calclcp(str, sa);
	tree.init();
	rsa.resize(sa.size());
	for (int i = 0; i < sa.size(); i++)rsa[sa[i]] = i;
	for (int i = 0; i < str.size(); i++)tree.update(i, lcp[i]);
}
int getlcp(int a, int b)//a文字目からとb文字目からのlcpの長さ
{
	return tree.get(min(rsa[a], rsa[b]) + 1, max(rsa[a], rsa[b]), 1, 0, SIZE - 1);
}
#define B 1250
#define NB 81
typedef long long ll;
ll rrui[NB][200010];
ll rans[NB][100010];
int stp[100000];
int toz[100000];
vector<int>kou[100000];
ll ans[100000];
class trie
{
public:
	int nex[100001][26];
	vector<int>dat[100001];
	int pt;
	void init()
	{
		pt = 1;
		fill(nex[0], nex[0] + 26, -1);
	}
	void adds(string s, int d)
	{
		int now = 0;
		for (int i = 0; i < s.size(); i++)
		{
			if (nex[now][s[i] - 'a'] == -1)
			{
				nex[now][s[i] - 'a'] = pt;
				fill(nex[pt], nex[pt] + 26, -1);
				pt++;
			}
			now = nex[now][s[i] - 'a'];
		}
		dat[now].push_back(d);
	}
	int getdest(string s)
	{
		int now = 0;
		for (int i = 0; i < s.size(); i++)
		{
			if (nex[now][s[i] - 'a'] == -1)break;
			now = nex[now][s[i] - 'a'];
		}
		return now;
	}
	vector<int>eul;
	int ord[100001];
	int fin[100001];
	void calceul(int node)
	{
		ord[node] = eul.size();
		eul.push_back(node);
		for (int i = 0; i < 26; i++)if (nex[node][i] != -1)calceul(nex[node][i]);
		fin[node] = eul.size() - 1;
	}
};
trie tr;
class sqq
{
public:
	ll now[B*NB];
	ll flag[NB];
	void resolve(int b)
	{
		for (int i = 0; i < B; i++)now[B*b + i] += flag[b];
		flag[b] = 0;
	}
	void add(int lb, int ub, int t)
	{
		int a = lb / B, b = ub / B;
		resolve(a);
		resolve(b);
		if (a == b)
		{
			for (int i = lb; i <= ub; i++)now[i] += t;
		}
		else
		{
			for (int i = lb; i < a*B + B; i++)now[i] += t;
			for (int i = a + 1; i <= b - 1; i++)flag[i] += t;
			for (int i = b*B; i <= ub; i++)now[i] += t;
		}
	}
	ll get(ll a)
	{
		return flag[a / B] + now[a];
	}
};
sqq bi;
typedef pair<int, int>pii;
vector<pii>que1[100000], que2[100000];
int main()
{
	int num, query;
	scanf("%d%d", &num, &query);
	vector<string>vec;
	for (int i = 0; i < num; i++)
	{
		string s;
		cin >> s;
		vec.push_back(s);
	}
	for (int i = 0; i < num; i++)
	{
		stp[i] = str.size();
		for (int j = 0; j < vec[i].size(); j++)str.push_back(vec[i][j] - 'a' + 1);
		str.push_back(27);
	}
	str.push_back(0);
	init();
	tr.init();
	for (int i = 0; i < vec.size(); i++)
	{
		tr.adds(vec[i], i);
	}
	int pt = 0;
	for (int i = 0; i < num; i++)
	{
		if (vec[i].size() >= B)
		{
			for (int j = stp[i]; j < stp[i] + vec[i].size(); j++)
			{
				rrui[pt][rsa[j] + 1]++;
			}
			for (int j = 1; j <= str.size(); j++)rrui[pt][j] += rrui[pt][j - 1];
			toz[i] = pt;
			pt++;
		}
		else
		{
			for (int j = 0; j < vec[i].size(); j++)
			{
				string zs;
				for (int k = j; k < vec[i].size(); k++)zs.push_back(vec[i][k]);
				kou[i].push_back(tr.getdest(zs));
			}
		}
	}
	for (int i = 0; i < num; i++)
	{
		int lb, ub;
		int beg = 0, end = rsa[stp[i]];
		for (;;)
		{
			if (beg == end)break;
			int med = (beg + end) / 2;
			if (getlcp(sa[med], sa[rsa[stp[i]]]) < vec[i].size())beg = med + 1;
			else end = med;
		}
		lb = beg;
		beg = rsa[stp[i]], end = str.size() - 1;
		for (;;)
		{
			if (beg == end)break;
			int med = (beg + end + 1) / 2;
			if (getlcp(sa[med], sa[rsa[stp[i]]]) < vec[i].size())end = med - 1;
			else beg = med;
		}
		ub = beg;
		for (int j = 0; j < pt; j++)
		{
			rans[j][i + 1] += rrui[j][ub + 1] - rrui[j][lb];
		}
	}
	for (int i = 0; i < pt; i++)
	{
		for (int j = 0; j < num; j++)
		{
			rans[i][j + 1] += rans[i][j];
		}
	}
	for (int p = 0; p < query; p++)
	{
		int za, zb, zc;
		scanf("%d%d%d", &za, &zb, &zc);
		za--;
		zb--;
		zc--;
		if (vec[zc].size() >= B)
		{
			ans[p] = rans[toz[zc]][zb + 1] - rans[toz[zc]][za];
		}
		else
		{
			que1[za].push_back(make_pair(zc, p));
			que2[zb].push_back(make_pair(zc, p));
		}
	}
	tr.calceul(0);
	for (int i = 0; i < num; i++)
	{
		for (int j = 0; j < que1[i].size(); j++)
		{
			for (int k = 0; k < kou[que1[i][j].first].size(); k++)
			{
				ans[que1[i][j].second] -= bi.get(tr.ord[kou[que1[i][j].first][k]]);
			}
		}
		if (vec[i].size() < B)bi.add(tr.ord[kou[i][0]], tr.fin[kou[i][0]], 1);
		for (int j = 0; j < que2[i].size(); j++)
		{
			for (int k = 0; k < kou[que2[i][j].first].size(); k++)
			{
				ans[que2[i][j].second] += bi.get(tr.ord[kou[que2[i][j].first][k]]);
			}
		}
	}
	for (int i = 0; i < query; i++)
	{
		printf("%I64d\n", ans[i]);
	}
}