问:Splay在时间和编程复杂度上都被碾压,为什么要学呢?
答:Splay是万金油啊,完爆各种乱七八糟的维护。(本质是懒得学新数据结构)
问:那么我阔以不看冗长的介绍吗?
答:其实窝并没有想要介绍。。。
【简介】
首先Spaly是个二叉搜索树,如果连这个都不知道,戳这
其次Splay有一个奇妙的提根操作,同时也是Splay的核心操作。
void rotate(int x,int &k){ int y=fa[x],z=fa[y],l,r; if (son[y][0]==x)l=0;else l=1;r=l^1; if (y==k)k=x; else {if (son[z][0]==y)son[z][0]=x;else son[z][1]=x;} fa[x]=z;fa[y]=x;fa[son[x][r]]=y; son[y][l]=son[x][r];son[x][r]=y; pushup(y);pushup(x); } void splay(int x,int &k){ while (x!=k){ int y=fa[x],z=fa[y]; if (y!=k){ if (son[y][0]==x^son[z][0]==y)rotate(x,k);else rotate(y,k); } rotate(x,k); } }
通过这个可以证明所有操作的均摊复杂度为4*logn。
显然加点、删点、询问第k大、询问一个数是第几大……这些操作都是很好写的。
那么如何给区间加值、求区间最值、区间反转呢?
窝萌阔以参照线段树的tag来给平衡树打tag,然后每次操作完后pushup、pushdown。
区间反转,其实就是左右儿子交换,不过这种情况要加两个点防止爆炸。
有兴趣的同学戳这
【例题】
Bzoj1552&3506
好像是傻逼题?加个区间反转和区间最值就能A。
Bzoj3173
基于数是一个一个插入的,显然插入的数比前面的数都大。
那么只要求区间最值来暴推lis就阔以辣!
【Codes】
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int M=100010; const int inf=1000000007; int n,i,cnt,root; int fa[M],son[M][2],rev[M],data[M],size[M],mx[M],pos[M],s[M],ans[M]; struct node{int v,pos;}a[M]; bool operator<(node x,node y){return x.pos<y.pos;} bool cmp(node x,node y){return x.v<y.v||(x.v==y.v&&x.pos<y.pos);} void pushup(int p){ int l=son[p][0],r=son[p][1]; pos[p]=p;mx[p]=data[p]; if (mx[l]<mx[p]){mx[p]=mx[l];pos[p]=pos[l];} if (mx[r]<mx[p]){mx[p]=mx[r];pos[p]=pos[r];} size[p]=size[l]+size[r]+1; } void pushdown(int p){ int l=son[p][0],r=son[p][1]; if (rev[p]){ rev[p]=0; rev[l]^=1;rev[r]^=1; swap(son[p][0],son[p][1]); } } void build(int l,int r,int p){ if (l>r)return; if (l==r){ fa[l]=p;size[l]=1; if (l<p)son[p][0]=l;else son[p][1]=l; pos[l]=l;mx[l]=data[l]=a[l].v; return; } int mid=l+r>>1; build(l,mid-1,mid); build(mid+1,r,mid); fa[mid]=p;data[mid]=a[mid].v;pushup(mid); if (mid<p)son[p][0]=mid;else son[p][1]=mid; } void rotate(int x,int &k){ int y=fa[x],z=fa[y],l,r; if (son[y][0]==x)l=0;else l=1;r=l^1; if (y==k)k=x; else {if (son[z][0]==y)son[z][0]=x;else son[z][1]=x;} fa[x]=z;fa[y]=x;fa[son[x][r]]=y; son[y][l]=son[x][r];son[x][r]=y; pushup(y);pushup(x); } void splay(int x,int &k){ int top=0;s[++top]=x; for (int i=x;fa[i];i=fa[i])s[++top]=fa[i]; for (int i=top;i;i--)if(rev[s[i]])pushdown(s[i]); while (x!=k){ int y=fa[x],z=fa[y]; if (y!=k){ if (son[y][0]==x^son[z][0]==y)rotate(x,k); else rotate(y,k); } rotate(x,k); } } int findkth(int k,int rank){ if (rev[k])pushdown(k); int l=son[k][0],r=son[k][1]; if (size[l]+1==rank)return k; else if (size[l]>=rank)return findkth(l,rank); else return findkth(r,rank-size[l]-1); } void reverse(int l,int r){ int x=findkth(root,l),y=findkth(root,r+2); splay(x,root);splay(y,son[x][1]); rev[son[y][0]]^=1; } int querymx(int l,int r){ int x=findkth(root,l),y=findkth(root,r+2); splay(x,root);splay(y,son[x][1]); return pos[son[y][0]]; } int main(){ scanf("%d",&n); a[1].v=a[n+2].v=inf;mx[0]=inf; for (i=2;i<=n+1;i++)scanf("%d",&a[i].v),a[i].pos=i; sort(a+2,a+n+2,cmp); for (i=2;i<=n+1;i++)a[i].v=i-1; sort(a+2,a+n+2); build(1,n+2,0); root=n+3>>1;cnt=n+2; for (i=1;i<=n;i++){ splay(querymx(i,n),root); ans[i]=size[son[root][0]]; reverse(i,ans[i]); } for (i=1;i<n;i++)printf("%d ",ans[i]); printf("%d\n",ans[n]); }
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int M=100010; int n,i,x,cnt,root; int fa[M],son[M][2],size[M],mx[M],data[M],ans[M]; void pushup(int p){ int l=son[p][0],r=son[p][1]; mx[p]=max(data[p],max(mx[l],mx[r])); size[p]=size[l]+size[r]+1; } void rotate(int x,int &k){ int y=fa[x],z=fa[y],l,r; if (son[y][0]==x)l=0;else l=1;r=l^1; if (y==k)k=x; else {if (son[z][0]==y)son[z][0]=x;else son[z][1]=x;} fa[x]=z;fa[y]=x;fa[son[x][r]]=y; son[y][l]=son[x][r];son[x][r]=y; pushup(y);pushup(x); } void splay(int x,int &k){ while (x!=k){ int y=fa[x],z=fa[y]; if (y!=k){ if (son[y][0]==x^son[z][0]==y)rotate(x,k);else rotate(y,k); } rotate(x,k); } } int findkth(int k,int rank){ int l=son[k][0],r=son[k][1]; if (size[l]+1==rank)return k; else if (size[l]>=rank)return findkth(l,rank); else return findkth(r,rank-size[l]-1); } int main(){ mx[0]=-1000000007; root=1;cnt=2;son[1][0]=2;fa[2]=1;size[1]=2;size[2]=1; scanf("%d",&n); for (i=1;i<=n;i++){ scanf("%d",&x); splay(findkth(root,x+2),root); int p=findkth(root,x+1); splay(p,son[root][0]); son[p][1]=++cnt; data[cnt]=mx[cnt]=ans[i]=mx[p]+1; fa[cnt]=p;size[cnt]=1; pushup(p);pushup(root); ans[i]=max(ans[i],ans[i-1]); } for (i=1;i<=n;i++)printf("%d\n",ans[i]); }