【题目描述】
Description
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。
Sample Input
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
Sample Output
2
8
9
105
7
HINT
HINT:
N,M<=100000
暴力自重。。。
Source
【题解】
强制在线,所以就不能离线乱搞了。
由于可持久化线段树是一个维护前缀和的东西,所以我们可以利用树上的前缀和来维护答案。
比如说你要知道权值范围l~r在u到v上有几个点,那么答案显然是:c[u]+c[v]-c[lca(u,v)]-c[fa[lca(u,v)]]。
所以对每个节点建树,继承它的父亲的树。
然后二分答案就可以了。
【Codes】
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int M=200010; int n,m,i,u,v,k,x,y,ans,cnt,edgenum; int a[M],b[M],c[M],head[M],vet[M],dep[M],next[M],root[M]; int fa[M][17]; struct node{int l,r,w;node(){l=r=w=0;}}tree[M*20]; bool cmp(int i,int j){return a[i]<a[j];} void addedge(int x,int y){ vet[++edgenum]=y; next[edgenum]=head[x]; head[x]=edgenum; } void change(int x,int l,int r,int &p){ tree[++cnt]=tree[p],p=cnt,tree[p].w++; if (l==r)return; int mid=l+r>>1; if (x<=mid)change(x,l,mid,tree[p].l);else change(x,mid+1,r,tree[p].r); } void dfs(int u){ int e,v; root[u]=root[fa[u][0]],change(c[u],1,n,root[u]); for (int i=1;i<=16;i++) if ((1<<i)<=dep[u])fa[u][i]=fa[fa[u][i-1]][i-1]; else break; for (e=head[u];e;e=next[e]){ v=vet[e]; if (fa[u][0]==v)continue; dep[v]=dep[u]+1; fa[v][0]=u; dfs(v); } } int lca(int x,int y){ if (dep[x]<dep[y])swap(x,y); int t=dep[x]-dep[y]; for (int i=0;i<=16;i++) if ((1<<i)&t)x=fa[x][i]; for (int i=16;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; if (x==y)return x; return fa[x][0]; } int solve(int x,int y,int k){ int a=x,b=y,c=lca(x,y),d=fa[c][0]; a=root[a],b=root[b],c=root[c],d=root[d]; int l=1,r=n; while (l<r){ int mid=l+r>>1; int tmp=tree[tree[a].l].w+tree[tree[b].l].w-tree[tree[c].l].w-tree[tree[d].l].w; if (tmp>=k)r=mid,a=tree[a].l,b=tree[b].l,c=tree[c].l,d=tree[d].l; else k-=tmp,l=mid+1,a=tree[a].r,b=tree[b].r,c=tree[c].r,d=tree[d].r; } return l; } int main(){ scanf("%d%d",&n,&m); for (i=1;i<=n;i++){scanf("%d",&a[i]);b[i]=i;} sort(b+1,b+n+1,cmp); for (i=1;i<=n;i++)c[b[i]]=i; for (i=1;i<n;i++){ scanf("%d%d",&x,&y); addedge(x,y); addedge(y,x); } dfs(1); for (i=1;i<=m;i++){ scanf("%d%d%d",&u,&v,&k); ans=a[b[solve(u^ans,v,k)]]; if (i!=m)printf("%d\n",ans);else printf("%d",ans); } }