可持久化

变简单了?!

Posted by JU on August 6, 2018

可持久化线段树

话说线段树大家应该都听说过,然而加上了“可持久化”四个大字后有什么区别呢??

概念

我们知道,每次单点修改,都需要更新约logn个线段树上的节点。   而可持久化线段树与普通线段树最大的区别在于,每次要修改一个数据时,不在原线段树上修改,而是在受到影响的节点“旁边”新建一个节点,上面记录新的数据。
而这些新建的节点与不受影响的节点,构成了一个历史版本。 同时,不再使用普通线段树的节点编号方式,而是根据新建节点的顺序进行编号,节约空间。这样,就需要额外记录每个节点的两个儿子的编号。
当然,每个版本的根节点额外记录一下。

模板题 luoguP3834

看到这题,就连傻子也知道要用可持久化线段树来做……(标题写着呢)  

思路

回到本题,首先依次插入序列中的每个数的排名,这样就得到了n个历史版本(类似前缀和的东西)。   每次询问一个区间,就调出第l-1与第r个历史版本,将其对应的节点相减,即可得知所需的信息。
当然,不是每个点都相减,而只是相关点。
最后进行一些细节处理即可。

#include <bits/stdc++.h>
#define sc(p) scanf("%d",&p)
using namespace std;
const int N=200100;
struct LGJ
{
    int id,pm,num;
}a[N];
bool LGJ1 (LGJ q1,LGJ q2) { return q1.num<q2.num; }
bool LGJ2 (LGJ q1,LGJ q2) { return q1.id<q2.id; }
struct TREEXL
{
    int sum,ls,rs;
    #define ls(p) t[p].ls
    #define rs(p) t[p].rs
    #define sum(p) t[p].sum
}t[N*20];
int n,m;
int tot=0,root[N]={0};
int rea[N];
int build (int l,int r)
{
    int p=++tot;
    if (l<r)
    {
	int mid=(l+r)/2;
	ls(p)=build (l,mid);
	rs(p)=build (mid+1,r);
    }
    return p;
}
int insert (int old,int l,int r,int x)
{
    int p=++tot;
    ls(p)=ls(old),rs(p)=rs(old),sum(p)=sum(old)+1;
    if (l<r)
    {
	int mid=(l+r)/2;
	if (x<=mid) ls(p)=insert (ls(old),l,mid,x);
	else rs(p)=insert (rs(old),mid+1,r,x);
    }
    return p;
}
int search (int pre,int aft,int l,int r,int k)
{
    if (l==r)  return l;
    int x=sum(ls(aft))-sum(ls(pre));
    int mid=(l+r)/2;
    if (x>=k) return search (ls(pre),ls(aft),l,mid,k);
    else return search (rs(pre),rs(aft),mid+1,r,k-x);
}
int b[N];
int main()
{
    //freopen ("data.out","r",stdin);
    //freopen ("mine.out","w",stdout);
    sc(n),sc(m);
    for (int i=1; i<=n; i++) sc(a[i].num),a[i].id=i;

    sort (a+1,a+n+1,LGJ1);
    int sum=0;
    for (int i=1; i<=n; i++)
    {
	if (i==1||a[i].num!=a[i-1].num) a[i].pm=++sum;
	else a[i].pm=sum;
	b[a[i].pm]=a[i].num;
    }
    sort (a+1,a+n+1,LGJ2);

    root[0]=build (1,sum);
    for (int i=1; i<=n; i++)
    root[i]=insert (root[i-1],1,sum,a[i].pm);

    for (int i=1; i<=m; i++)
    {
	int x,y,k; sc(x),sc(y),sc(k);
	printf ("%d\n",b[search (root[x-1],root[y],1,sum,k)]);
    }
    return 0;
} ### 区间修改 HDU4348

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <iostream>
#define sc(p) scanf("%d",&p)
#define sc2(p1,p2) sc(p1),sc(p2)
#define pr(p) printf("%lld\n",p)
const int oo=2147483640;
const int N=101000;
using namespace std;
int a[N];
struct TREEXL
{
    long long sum,add;int ls,rs;
    #define ls(p) t[p].ls
    #define rs(p) t[p].rs
    #define sum(p) t[p].sum
    #define add(p) t[p].add
}t[N*25];
long long n,m;
int tot=0,root[N]={0};
int build (int l,int r)
{
    int p=++tot;
    if (l==r) { sum(p)=a[l]; return p; }
    if (l<r)
    {
	int mid=(l+r)/2;
	ls(p)=build (l,mid);
	rs(p)=build (mid+1,r);
	sum(p)=sum(ls(p))+sum(rs(p));
    }
    return p;
}
int change (int old,int L,int R,int l,int r,long long x)
{
	int p=++tot;
	ls(p)=ls(old),rs(p)=rs(old),sum(p)=sum(old),add(p)=add(old);
	if (l<=L&&R<=r) {
	sum(p)+=(R-L+1)*x,add(p)+=x;
	return p; }
	long long mid=(L+R)>>1;
	if (l<=mid) ls(p)=change (ls(old),L,mid,l,r,x);
	if (r>mid) rs(p)=change (rs(old),mid+1,R,l,r,x);
	sum(p)=sum(ls(p))+sum(rs(p))+add(p)*(R-L+1);
	return p;
}
long long query (int p,int L,int R,int l,int r,long long addv)
{
	long long ans=0;
	if (l<=L&&R<=r) return sum(p)+addv*(R-L+1);
	long long mid=(L+R)>>1; addv+=add(p);
	if (l<=mid) ans+=query (ls(p),L,mid,l,r,addv);
	if (r>mid) ans+=query (rs(p),mid+1,R,l,r,addv);
	return ans;
}
int main()
{
	int n,m;
    while (sc(n)!=EOF)
    {
	tot=0;
	memset (root,0,sizeof (root));
	memset (t,0,sizeof (t));
	memset (a,0,sizeof (a));
	sc(m);
		for (long long i=1; i<=n; i++) sc(a[i]);
		root[0]=build (1,n);
		long long cur=0;
		for (long long i=1; i<=m; i++)
		{
			char c[2]; int l,r,t; scanf ("%s",c);
			if (c[0]=='C') { sc2(l,r),sc(t);
			root[cur+1]=change (root[cur],1,n,l,r,t);
			cur++; }
			if (c[0]=='Q') { sc2(l,r);
			long long g=query (root[cur],1,n,l,r,0);
			pr(g); }
			if (c[0]=='H') { sc2(l,r),sc(t);
			long long g=query (root[t],1,n,l,r,0);
			pr(g);
			}
			if (c[0]=='B') { sc(cur); }
		}
	}
    return 0;
}

可持久化Trie

更好例题

HDU4757

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <queue>
#define sc(p) scanf("%d",&p)
#define sc2(p1,p2) sc(p1),sc(p2)
#define pr(p) printf("%d\n",p)
const int oo=2147483640;
const int N=100010;
using namespace std;
inline void read (int &x)
{
    x=0; static int p; p=1; static char c; c=getchar();
    while (!isdigit(c)) { if (c=='-') p=-1; c=getchar(); }
    while ( isdigit(c)) { x=(x<<1)+(x<<3)+(c-48); c=getchar(); }
    x*=p;
}
struct EDGE { int v,nx; }lb[N<<1]; int tot=1,top[N];
void add (int u,int v) { lb[++tot].v=v,lb[tot].nx=top[u],top[u]=tot; }
queue <int> q; int f[N][50],d[N],t;
void bfs ()
{
    while (q.size()) q.pop();
    q.push (1); d[1]=1;
    while (q.size())
    {
	int u=q.front(); q.pop();
	for (int kb=top[u]; kb!=-1; kb=lb[kb].nx)
	{
	    int v=lb[kb].v;
	    if (d[v]) continue;
	    d[v]=d[u]+1;
	    f[v][0]=u;
	    for (int j=1; j<=t; j++)
	    f[v][j]=f[f[v][j-1]][j-1];
	    q.push(v);
	}
    }
}
int lca (int x,int y)
{
    if (d[x]>d[y]) swap (x,y);
    for (int i=t; i>=0; i--)
    if (d[f[y][i]]>=d[x]) y=f[y][i];
    if (x==y) return x;
    for (int i=t; i>=0; i--)
    if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    return f[x][0];
}
struct TRIE { int ls,rs,s; }trie[N*60]; int sum=0;
#define ls(p) trie[p].ls
#define rs(p) trie[p].rs
#define s(p) trie[p].s
int val[N];
int root[N]; bool vis[N]={0};
int insert (int p,int va,int wei)
{
    int cur=++sum;
    s(cur)=s(p)+1;
    if (wei==-1) return cur;
    ls(cur)=ls(p),rs(cur)=rs(p);
    if ((1<<wei)&va) rs(cur)=insert (rs(p),va,wei-1);
    else ls(cur)=insert (ls(p),va,wei-1);
    return cur;
}
void build (int u,int fa)
{
    vis[u]=1;
    root[u]=insert (root[fa],val[u],t);
    for (int kb=top[u]; kb!=-1; kb=lb[kb].nx)
    {
	int v=lb[kb].v;
	if (vis[v]) continue;
	build (v,u);
    }
    return ;
}
int mi2[20];
int find (int pre,int aft,int va,int wei)
{
    int xl=s(ls(aft))-s(ls(pre)),xr=s(rs(aft))-s(rs(pre));
    if (wei==-1) return 0;
    if ((1<<wei)&va)
    {
	if (xl) return mi2[wei]+find (ls(pre),ls(aft),va,wei-1);
	else return find (rs(pre),rs(aft),va,wei-1);
    }
    else
    {
	if (xr) return mi2[wei]+find (rs(pre),rs(aft),va,wei-1);
	else return find (ls(pre),ls(aft),va,wei-1);
    }
}
int main()
{
    mi2[0]=1;
    for (int i=1; i<=18; i++) mi2[i]=mi2[i-1]*2;
    t=15;
    int n,m;
    while (sc2(n,m)!=EOF)
    {
	memset (top,-1,sizeof (top));
	memset (d,0,sizeof (d));
	memset (f,0,sizeof (f));
	memset (root,0,sizeof (root));
	memset (vis,0,sizeof (vis));
	memset (trie,0,sizeof (trie));
	memset (val,0,sizeof (val));
	tot=1;
	for (int i=1; i<=n; i++) sc(val[i]);
	for (int i=1; i< n; i++)
	{
	    int x,y; sc2(x,y);
	    add (x,y),add (y,x);
	}
	build (1,0);
	bfs ();
	for (int i=1; i<=m; i++)
	{
	    int x,y,k; sc2(x,y),sc(k);
	    int z=lca (x,y);
	    int ans=val[z]^k;
	    int g=find (root[z],root[x],k,t);
	    ans=max (ans,g);
	    g=find (root[z],root[y],k,t);
	    ans=max (ans,g);
	    pr(ans);
	}
    }

    return 0;
}

例题:SMOJ2505

用于搞笑。

思路

等等,不是说处理字符串吗?!咋变成数了!?没事,注意到xor运算,于是应想到将其二进制按位分解。而能够储存这些二进制形式的最佳数据结构,莫过于Trie。(虽然比赛时完全没想到)
而数据又有一个奇怪之处:n<=1000,m<=300000,这明显暗示处理行时暴力,处理列时用特殊方法。  

做法

于是构建一棵可持久化Trie,依次将给出的Y[]插入(从高位到低位),并记录root[];同时,记录当前节点的路径是多少个数(二进制)的前缀。  

bool check (int a,int w)
{
    return (a&(1<<(w-1)));
}
int L[N],R[N];
void insert (int w,int old,int &now,int v)
{
    now=++tot;
    ls(now)=ls(old),rs(now)=rs(old);
    cnt(now)=cnt(old)+1;
    if (w==0) return ;
    if (check (v,w)) 
        insert (w-1,rs(old),rs(now),v);
    else 
        insert (w-1,ls(old),ls(now),v);
}

每次读入一个询问,就调出区间[l,r]的信息L[]R[](实际上只是调出root[l-1]root[r]),进行查询。

1.为方便处理,首先将“第k大”转变为”第((矩形大小)-k+1)小“(改变k)。

2.由Tire从高位到低位插入易知,当某行的w位(二进制)为0/1时,将Tire上当前节点的leftson(0在左)/rightson的数量累加起来,由于当前位相同,异或的结果也最小,由得到的结果sum可知矩形内前sum小的数的信息(如在哪个儿子内)。

int sum=0;
for (int i=u; i<=d; i++)
	if (check (x[i],w)) 
		sum+=cnt(rs(R[i]))-cnt(rs(L[i]));
	else 
		sum+=cnt(ls(R[i]))-cnt(ls(L[i]));

3.由此可以分类讨论:
1)sum<k。此时所有L[]R[]都移向能使sum增大的儿子处,并使k减小sum。同时由于返回的是最终答案(一个值,而非一个位置),应加上2^(??)。这样做是因为由于移向了异或后的结果为1的位置,而除去了当前这一位,要在答案中补上。

2)sum>k。同样将LR移向使sum减小的儿子。此时由于移向异或结果为0的儿子,这一位对答案无影响,无需加2^(??)。

if (sum<k)
{	
	for (int i=u; i<=d; i++)
		if (check (x[i],w)) 
			L[i]=ls(L[i]),R[i]=ls(R[i]); 
		else 
			L[i]=rs(L[i]),R[i]=rs(R[i]); 

	return query (w-1,k-sum,u,d)+(1<<(w-1));
}
else
{
	for (int i=u; i<=d; i++)
		if (check (x[i],w)) 
			L[i]=rs(L[i]),R[i]=rs(R[i]);
	else 
		L[i]=ls(L[i]),R[i]=ls(R[i]);

	return query (w-1,k,u,d);
}

可持久化Treap

模板LuoguP3835

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <iostream>
#define sc(p) scanf("%d",&p)
#define sc2(p1,p2) sc(p1),sc(p2)
#define pr(p) printf("%d\n",p)
const int oo=2147483640;
const int N=1010000;
using namespace std;
struct TREAP { int ls,rs,siz,v,w,cnt; }t[N]; int tot=0,root[N];
#define ls(p) t[p].ls
#define rs(p) t[p].rs
#define siz(p) t[p].siz
#define v(p) t[p].v
#define w(p) t[p].w
#define cnt(p) t[p].cnt
inline void read (int &x)
{
	x=0; static int p; p=1; static char c; c=getchar();
	while (!isdigit(c)) { if (c=='-') p=-1; c=getchar(); }
	while ( isdigit(c)) { x=(x<<1)+(x<<3)+(c-48); c=getchar(); }
	x*=p;
}
int copy (int x) { int y=++tot; t[y]=t[x]; return y; }
void update (int x) { if (x) siz(x)=siz(ls(x))+cnt(x)+siz(rs(x)); }
int newnode (int val)
{
	tot++;
	ls(tot)=rs(tot)=0;
	v(tot)=val;
	siz(tot)=cnt(tot)=1;
	w(tot)=rand();
	return tot;
}
int merge (int a,int b)
{
	if (!a||!b) return a+b;
	int x;
	if (w(a)>w(b)) x=copy (a),rs(x)=merge (rs(a),b);
	else           x=copy (b),ls(x)=merge (a,ls(b));
	update (x);
	return x;
}
void split (int x,int k,int &a,int &b)
{
	if (!x) { a=b=0; return ; }
	if (v(x)==k)
	{
		if (ls(x)) a=copy (ls(x)); else a=0;
		if (rs(x)) b=copy (rs(x)); else b=0;
	}
	if (v(x)<k)
	{
		a=copy (x);
		split (rs(x),k,rs(a),b);
	}
	if (v(x)>k)
	{
		b=copy (x);
		split (ls(x),k,a,ls(b));
	}
	update (a),update (b);
}
int find (int x,int v)
{
	if (!x) return 0;
	if (v==v(x)) return x;
	return ( (v<v(x)) ? find (ls(x),v) : find (rs(x),v) );
}
void insert (int x,int k,int &u)
{
	int pos=find (x,k),a,b;
	split (x,k,a,b);
	int p=newnode (k);
	if (pos) siz(p)=cnt(p)=cnt(pos)+1;
	u=merge (merge (a,p),b);
}
void remove (int x,int k,int &u)
{
	int pos=find (x,k),a,b;
	if (!pos) return ;
	if (cnt(pos)>1)
	{
		int p=newnode (v(pos));
		siz(p)=cnt(p)=cnt(pos)-1;
		split (x,k,a,b);
		u=merge (merge (a,pos),b);
	}
	else
	{
		split (x,k,a,b);
		u=merge (a,b);
	}
}
int getpre (int x,int v)
{
	if (!x) return 0;
	if (v(x)<v)
	{
		int p=getpre (rs(x),v);
		return (!p) ? x : p;
	}
	else return getpre (ls(x),v);
}
int getnex (int x,int v)
{
	if (!x) return 0;
	if (v(x)>v)
	{
		int p=getnex (ls(x),v);
		return (!p) ? x : p;
	}
	else return getnex (rs(x),v);
}
int getrank (int x,int v)
{
	if (!x) return 0;
	if (v==v(x)) return siz(ls(x));
	if (v< v(x)) return getrank (ls(x),v);
	else return getrank (rs(x),v)+siz(ls(x))+cnt(x);
}
int getkth (int x,int k)
{
	if (siz(ls(x))+1<=k&&k<=siz(ls(x))+cnt(x)) return x;
	if (k<=siz(ls(x))) return getkth (ls(x),k);
	else return getkth (rs(x),k-siz(ls(x))-cnt(x));
}
int main()
{
    //freopen(".in","r",stdin);
    //freopen(".out","w",stdout);
	int n; sc(n);
	int st=newnode (-oo),ed=newnode (oo);
	(w(st)>w(ed)) ? rs(st)=ed : ls(ed)=st;
	root[0]= (w(st)>w(ed)) ? st : ed;
	for (int i=1; i<=n; i++)
	{
		int u,op,x; sc2(u,op),sc(x);
		root[i]=root[u];
		if (op==1) insert (root[u],x,root[i]);
		if (op==2) remove (root[u],x,root[i]);
		if (op==3) pr(getrank (root[u],x));
		if (op==4) pr(v(getkth (root[u],x+1)));
		if (op==5) pr(v(getpre (root[u],x)));
		if (op==6) pr(v(getnex (root[u],x)));
	}
    return 0;
}
CC 原创文章采用CC BY-NC-SA 4.0协议进行许可,转载请注明:“转载自:可持久化