zkw segment-tree 真是太棒了(真的重口味)!写篇博客纪念入门
emmm…首先我们来介绍一下 zkw 线段树这个东西(俗称 “重口味” ,与 KMP 类似,咳咳…)
zkw 线段树的介绍
其实 zkw 线段树和普通线段树区别没多大(区别可大了去了!)
emmm…起码它们的思想是一致的,都是节点维护区间信息嘛。
只不过…普通线段树的维护和查询是递归式,而 zkw线段树是循环式的…
但是不要以为 zkw线段树只是靠循环加速上位的!
zkw线段树能支持非常多强(luan)如(qi)闪(ba)电(zao)的操作(最后例题讲)。
zkw 线段树 与普通线段树 的比较
emmm…这里你看着 普通线段树 的节点比 zkw线段树 的小对吧,但其实两者差不多,(因为线段树是要开4倍空间的啊,这里只是没有画出用不到的节点罢了),
zkw 线段树的形态
其实上图…还是无法体现zkw 线段树的具体形态的,(但是相信聪明的你一定看懂了所以我就不讲了)
emmm…于是乎还是上图解释一切
zkw 线段树的建立
首先你要写个循环,让 m 这个值(也就是非叶子节点)大于 n (也就是总叶子结点数),以此保证 这棵树的叶子 能够容纳你要维护的 n 个值
然后你要从 m 倒推 到 1 号节点(注意是 m 倒推回 1 ,保证维护每个节点时该节点的孩子都已经被维护完毕),让每个节点维护它左右孩子的信息。
代码实现
这里我们假设要维护的信息有:区间和,区间最小值,区间最大值 。 下同
inline void build(int n){ // 维护这么多信息都只需要这么几行,可见维护信息单一时代码应该会短的不像话(压行过的话大概三四行) for(m=1;m<=n;m<<=1); for(int i=m+1;i<=m+n;++i) sum[i]=mn[i]=mx[i]=read(); for(int i=m-1;i;--i) sum[i]=a[i<<1]+a[i<<1|1], mn[i]=min(mn[i<<1],mn[i<<1|1]), mx[i]=max(mx[i<<1],mx[i<<1|1]); }
但是,这里对 mn 和 mx 的处理是在无修改操作的基础上实行的,所以这样写并不支持修改操作。
那么我们可以这样写:
inline void build(){ for(m=1;m<=n;m<<=1); for(int i=m+1;i<=m+n;++i) sum[i]=mn[i]=mx[i]=read(); for(int i=m-1;i;--i){ sum[i]=sum[i<<1]+sum[i<<1|1]; mn[i]=min(mn[i<<1],mn[i<<1|1]), mn[i<<1]-=mn[i],mn[i<<1|1]-=mn[i]; mx[i]=max(mx[i<<1],mx[i<<1|1]), mx[i<<1]-=mx[i],mx[i<<1|1]-=mx[i]; } }
PS:以下的操作(单点、区间更新,单点、区间查询)所附的代码,都基于可修改的版本
zkw 线段树的更新
单点更新
这个单点更新还是比较好解决的,你只要找到更新的节点所在的叶子结点,然后修改后一直向父节点更新即可。
(这个。。。就不用上图了吧…你脑补一下就差不多了)
代码实现
这里我们假设将一个节点的值增加 v (修改的话…就记录一下原数组,然后算差值就好了吧?)
inline void update_node(int x,int v,int A=0){ x+=m,mx[x]+=v,mn[x]+=v;for(;x>1;x>>=1){ sum[x]+=v; A=min(mn[x],mn[x^1]); mn[x]-=A,mn[x^1]-=A,mn[x>>1]+=A; A=max(mx[x],mx[x^1]), mx[x]-=A,mx[x^1]-=A,mx[x>>1]+=A; } }
区间更新
这个东西…有点麻烦(你得稍微感性理解)。
就是说…你每次要更新一段区间的时候,你要让左端点 -1 ,右端点 +1 。
然后你在更新权值的时候要判断 左端点当前所处的节点是否是它父节点的左孩子,
是的话就让该节点的兄弟(也就是它父节点的右孩子)得到更新,否则不做处理,
然后左节点再向右移一位(也就是跳到了父节点),重复迭代以上步骤。
那么右端点呢?其实也就是和左端点反着来了而已。
还有一点,循环的终止条件?这个简单,就是当左右端点所处的节点是兄弟节点的时候结束循环。
类似的,你更新一个节点时 同样可以用这种方法维护(只不过这样就更麻烦了啊)。
这样我们可以看到要被更新的区间都已经被染成黄色了。但是,zkw 没有下传标记啊!
那么我们查询的区间如果在染成黄色的节点的下部(也就是黄色节点的子树内)该怎么办?
我们可以这样…这样…没错!标记永久化!
因为我们已经将一个节点的标记永久化了,那么在该节点被访问到的时候,只要将当前查询到的、包含在该节点所管辖区间范围内的 区间长度乘上标记值,累加入答案即可。
(具体实现得看代码)
区间更新的特殊情况
同学们有没有注意到一种区间查询的特殊情况?没错,就是右区间+1后到达下一层的特殊情况
就以上图为例,假设维护区间为 1 ~ 7 ,现在对 2 ~ 7 进行区间加操作,那么 t = 7+1 = 8 ,于是 t 就到达了不存在的第 5 层!
现在你想的一定是这种情况该怎么避免这种情况(其实很简单,你在建树确定 m 的值的时候,将判断条件改成 ” m<=n+1 ” 就行了)
但我现在要证明这种情况不需要避免也不会出问题(基本上…吧?)
我们可以看到,s 和 t 在跳到 0 和 1 时满足了终止条件,并且需要更新的节点都得到了更新,而且,其实 t 就没有更新过节点…
代码实现
这里我们假设要将一段区间的每个数加上 v ,然后维护的信息同上
inline void update_part(int s,int t,int v){ int A=0,lc=0,rc=0,len=1; for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){ //在这里的 add 就是标记数组了 if(s&1^1) add[s^1]+=v,lc+=len, mn[s^1]+=v,mx[s^1]+=v; if(t&1) add[t^1]+=v,rc+=len, mn[t^1]+=v,mx[t^1]+=v; sum[s>>1]+=v*lc, sum[t>>1]+=v*rc; A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A, A=min(mn[t],mn[t^1]),mn[t]-=A,mn[t^1]-=A,mn[t>>1]+=A; A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A, A=max(mx[t],mx[t^1]),mx[t]-=A,mx[t^1]-=A,mx[t>>1]+=A; } for(lc+=rc;s>1;s>>=1){ sum[s>>1]+=v*lc; A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A, A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A; } }
这里的 lc 和 rc 的所代表的含义需要讲一下
lc 代表左端点所处的节点下有多少长度的区间在更新区间内, rc 同理 ,通俗一点地说,就是 s 和 t 所分别走过的节点中包含的更新过的区间的总长
zkw线段树的查询
单点查询
这个没什么好说的吧,你从叶子结点一直跳父节点,把途中节点的 mn (或者 mx )权值累加,最后得到的就是答案
代码实现
inline int query_node(int x,int ans=0){ for(x+=m;x;x>>=1) ans+=mn[s]; return ans; }
区间查询
什么?zkw线段树的区间查询?我不会啊。
那么这里的区间查询…其实有点难说啊!要不就直接上代码得了?咳咳…
这个其实和上面的区间更新的思路差不多,可能要讲的就是标记累加的问题了吧。
那么 lc 和 rc 之前已经讲过了,就是 s 节点和 t 节点分别走过的节点中所包含的更新区间的长度。
那么 add 这个数组啊…啊…啊…这个数组啊,它…要不我们直接看代码吧?
它好在哪里啊?好难说啊…其实它就是记录了你每次大块累加区间时的副产品啊,类似于线段树的懒标记。
但是和普通线段树不一样的是,线段树的查询是自上而下查询(顺便释放标记)然后又自下而上的递归回去的,
而 zkw 的查询是直接自下而上的,于是它无法释放标记,于是它就在遇到某个打过懒标记的节点时,将当前查询到的区间长度乘上标记值,累加入答案。
(所以这还是懒标记啊!不上图了自行脑补。emmm…算了吧那还是上一张图好了)
代码实现
inline int query_sum(int s,int t){ int lc=0,rc=0,len=1,ans=0; for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){ if(s&1^1) ans+=sum[s^1]+len*add[s^1],lc+=len; if(t&1) ans+=sum[t^1]+len*add[t^1],rc+=len; if(add[s>>1]) ans+=add[s>>1]*lc; if(add[t>>1]) ans+=add[t>>1]*rc; } for(lc+=rc,s>>=1;s;s>>=1) if(add[s]) ans+=add[s]*lc; return ans; } inline int query_min(int s,int t,int L=0,int R=0,int ans=0){ if(s==t) return query_node(s); // 单点要特判, 下同 for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){ // 这里 s 和 t 直接加上 m L+=mn[s],R+=mn[t]; if(s&1^1) L=min(L,mn[s^1]); if(t&1) R=min(R,mn[t^1]); } for(ans=min(L,R),s>>=1;s;s>>=1) ans+=mn[s]; return ans; } inline int query_max(int s,int t,int L=0,int R=0,int ans=0){ if(s==t) return query_node(s); for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){ L+=mx[s],R+=mx[t]; if(s&1^1) L=max(L,mx[s^1]); if(t&1) R=max(R,mx[t^1]); } for(ans=max(L,R),s>>=1;s;s>>=1) ans+=mx[s]; return ans; }
这里询问时 s 和 t 不能 -1 或 +1 ,不然会查询到旁边不相干的节点。
然后 s == t 的情况要特判一下,防止 s 和 t 一直都不是兄弟,陷入死循环。
zkw 的代码实现(模板)
完全代码
1 //by Judge 2 #include<cstdio> 3 #include<iostream> 4 using namespace std; 5 const int M=1e5+111; 6 //#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) 7 char buf[1<<21],*p1=buf,*p2=buf; 8 inline int read(){ 9 int x=0,f=1; char c=getchar(); 10 for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; 11 for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; 12 } 13 char sr[1<<21],z[20];int C=-1,Z; 14 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;} 15 inline void print(int x){ 16 if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x; 17 while(z[++Z]=x%10+48,x/=10); 18 while(sr[++C]=z[Z],--Z);sr[++C]=' '; 19 } 20 int n,m,q; 21 int sum[M<<2],mn[M<<2],mx[M<<2],add[M<<2]; 22 inline void build(){ 23 for(m=1;m<=n;m<<=1); 24 for(int i=m+1;i<=m+n;++i) 25 sum[i]=mn[i]=mx[i]=read(); 26 for(int i=m-1;i;--i){ 27 sum[i]=sum[i<<1]+sum[i<<1|1]; 28 mn[i]=min(mn[i<<1],mn[i<<1|1]), 29 mn[i<<1]-=mn[i],mn[i<<1|1]-=mn[i]; 30 mx[i]=max(mx[i<<1],mx[i<<1|1]), 31 mx[i<<1]-=mx[i],mx[i<<1|1]-=mx[i]; 32 } 33 } 34 inline void update_node(int x,int v,int A=0){ 35 x+=m,mx[x]+=v,mn[x]+=v,sum[x]+=v; 36 for(;x>1;x>>=1){ 37 sum[x]+=v; 38 A=min(mn[x],mn[x^1]); 39 mn[x]-=A,mn[x^1]-=A,mn[x>>1]+=A; 40 A=max(mx[x],mx[x^1]), 41 mx[x]-=A,mx[x^1]-=A,mx[x>>1]+=A; 42 } 43 } 44 inline void update_part(int s,int t,int v){ 45 int A=0,lc=0,rc=0,len=1; 46 for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){ 47 if(s&1^1) add[s^1]+=v,lc+=len, mn[s^1]+=v,mx[s^1]+=v; 48 if(t&1) add[t^1]+=v,rc+=len, mn[t^1]+=v,mx[t^1]+=v; 49 sum[s>>1]+=v*lc, sum[t>>1]+=v*rc; 50 A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A, 51 A=min(mn[t],mn[t^1]),mn[t]-=A,mn[t^1]-=A,mn[t>>1]+=A; 52 A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A, 53 A=max(mx[t],mx[t^1]),mx[t]-=A,mx[t^1]-=A,mx[t>>1]+=A; 54 } 55 for(lc+=rc;s;s>>=1){ 56 sum[s>>1]+=v*lc; 57 A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A, 58 A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A; 59 } 60 } 61 inline int query_node(int x,int ans=0){ 62 for(x+=m;x;x>>=1) ans+=mn[x]; return ans; 63 } 64 inline int query_sum(int s,int t){ 65 int lc=0,rc=0,len=1,ans=0; 66 for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){ 67 if(s&1^1) ans+=sum[s^1]+len*add[s^1],lc+=len; 68 if(t&1) ans+=sum[t^1]+len*add[t^1],rc+=len; 69 if(add[s>>1]) ans+=add[s>>1]*lc; 70 if(add[t>>1]) ans+=add[t>>1]*rc; 71 } 72 for(lc+=rc,s>>=1;s;s>>=1) if(add[s]) ans+=add[s]*lc; 73 return ans; 74 } 75 inline int query_min(int s,int t,int L=0,int R=0,int ans=0){ 76 if(s==t) return query_node(s); 77 for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){ 78 L+=mn[s],R+=mn[t]; 79 if(s&1^1) L=min(L,mn[s^1]); 80 if(t&1) R=min(R,mn[t^1]); 81 } 82 for(ans=min(L,R),s>>=1;s;s>>=1) ans+=mn[s]; 83 return ans; 84 } 85 inline int query_max(int s,int t,int L=0,int R=0,int ans=0){ 86 if(s==t) return query_node(s); 87 for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){ 88 L+=mx[s],R+=mx[t]; 89 if(s&1^1) L=max(L,mx[s^1]); 90 if(t&1) R=max(R,mx[t^1]); 91 } 92 for(ans=max(L,R),s>>=1;s;s>>=1) ans+=mx[s]; 93 return ans; 94 } 95 96 signed main(){ 97 98 99 100 101 102 return 0; 103 }
View Code
板子题?这个真没有…(不过你可以拿普通线段树的板子题等练手)
默默放上线段树板子题的链接…
1. 线段树 1
2. 线段树 2
代码
1.
1 //by Judge 2 #include<cstdio> 3 #include<iostream> 4 #define ll long long 5 using namespace std; 6 const int M=1e5+111; 7 //#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) 8 char buf[1<<21],*p1=buf,*p2=buf; 9 inline ll read(){ 10 ll x=0,f=1; char c=getchar(); 11 for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; 12 for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; 13 } 14 char sr[1<<21],z[20];int C=-1,Z; 15 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;} 16 inline void print(ll x){ 17 if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x; 18 while(z[++Z]=x%10+48,x/=10); 19 while(sr[++C]=z[Z],--Z);sr[++C]=' '; 20 } 21 ll n,m,q; 22 ll sum[M<<2],add[M<<2]; 23 inline void build(){ 24 for(m=1;m<=n;m<<=1); 25 for(int i=m+1;i<=m+n;++i) sum[i]=read(); 26 for(int i=m-1;i;--i) sum[i]=sum[i<<1]+sum[i<<1|1]; 27 } 28 inline void update_part(int s,int t,ll v){ 29 ll A=0,lc=0,rc=0,len=1; 30 for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){ 31 if(s&1^1) add[s^1]+=v,lc+=len; 32 if(t&1) add[t^1]+=v,rc+=len; 33 sum[s>>1]+=v*lc,sum[t>>1]+=v*rc; 34 } for(lc+=rc,s>>=1;s;s>>=1) sum[s]+=v*lc; 35 } 36 inline ll query_sum(int s,int t){ 37 ll lc=0,rc=0,len=1,ans=0; 38 for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){ 39 if(s&1^1) ans+=sum[s^1]+len*add[s^1],lc+=len; 40 if(t&1) ans+=sum[t^1]+len*add[t^1],rc+=len; 41 if(add[s>>1]) ans+=add[s>>1]*lc; 42 if(add[t>>1]) ans+=add[t>>1]*rc; 43 } for(lc+=rc,s>>=1;s;s>>=1) if(add[s]) ans+=add[s]*lc; 44 return ans; 45 } 46 signed main(){ 47 n=read(),q=read(),build(); 48 int opt,x,y; ll k; 49 while(q--){ 50 opt=read(),x=read(),y=read(); 51 if(opt&1) k=read(),update_part(x,y,k); 52 else print(query_sum(x,y)); 53 } Ot(); return 0; 54 }
View Code
2.
emmm…实在是太晚啦(其实是没有研究过区间乘),所以就…您就自个儿研究吧~~~
推荐例题
题目: 无聊的数列
其实这道题用普通线段树 + 懒标记也可以做 (你可以试试?)
但是用了 zkw 之后…那个代码量的差别,我都不想说什么…(诶?貌似普通线段树用了标记永久化之后差不多也是这个码量?)
代码
1 //by Judge 2 #include<cstdio> 3 #include<iostream> 4 using namespace std; 5 const int M=1<<20; 6 int n,m,q,opt,L,R,k,d; 7 int a[M],lt[M],dt[M]; 8 //#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) 9 char buf[1<<21],*p1=buf,*p2=buf; 10 inline int read(){ 11 int x=0,f=1; char c=getchar(); 12 for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; 13 for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; 14 } 15 char sr[1<<21],z[20];int C=-1,Z; 16 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;} 17 inline void print(int x){ 18 if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x; 19 while(z[++Z]=x%10+48,x/=10); 20 while(sr[++C]=z[Z],--Z);sr[++C]=' '; Ot(); 21 } 22 inline void build(){ 23 for(int i=m;i;--i) lt[i]=lt[i<<1]; 24 } 25 inline void update(int L,int R,int k,int d){ //update 还是蛮常规的 26 for(int l=L+m-1,r=R+m+1;l^r^1;l>>=1,r>>=1){ 27 if(l&1^1) a[l^1]+=k+(lt[l^1]-L)*d,dt[l^1]+=d; 28 if(r&1) a[r^1]+=k+(lt[r^1]-L)*d,dt[r^1]+=d; 29 } 30 } 31 inline int query(int p,int res){ //query 感性理解一下:非叶子节点存储的是附加值,也就是操作 1 当中加入的等差数列 32 for(int i=m+p;i;i>>=1) res+=a[i]+(p-lt[i])*dt[i]; 33 return res; 34 } 35 int main(){ 36 n=read(),q=read(); for(m=1;m<=n;m<<=1); printf("%d ",m); 37 for(int i=1;i<=n;++i) a[m+i]=read(),lt[m+i]=i; 38 build(); 39 while(q--){ 40 opt=read(); 41 if(opt&1) L=read(),R=read(),k=read(),d=read(),update(L,R,k,d); 42 else k=read(),print(query(k,0)); 43 } Ot(); return 0; 44 }
View Code
最后推荐一下: 某位大佬的 blog (写的也蛮详细的但没我详细,emmm…但是他那片博客里的区间求最值是错的,坑!)
最新评论