[总结][数据结构]ZKW线段树详解

一些区间统计类型的题,如果满足区间减法的话,恐怕大家都会用树状数组对吧。。。

原因有两点:壹、好写;贰、常数小。

然而ZKW天牛发明的ZKW线段树也具有以上的这些优点,而且更加灵活。

对于需要区间延迟标记的题目,也完全没有必要转化为很多不同的信息加以维护,通过一些做法,可以做到像递归版线段树一样下传标记,而且代码简单,常数极小。(但确实有一些题目只能用递归版线段树才能处理,事实上,递归版线段树具有通用性。)

ZKW线段树的来源是一篇名为《统计的力量》的PPT,看完使人醍醐灌顶。建议先好好看看这个~

传送门

ZKW线段树的存储方式为堆式存储,这个不用多说了吧~~

首先假设我们维护的区间是,由于我们知道查询的实际上是(两端)开区间,所以底层的区间事实上是,以用来应对这种询问。因此,底层至少有个节点。

设底层的节点数为,则可以用这样一段代码处理:

for(M=1;M<(n+2);M<<=1);

时,易知,我们给出此时的线段树示例:

。。。由于我是蒟蒻就凑合看吧。。这样就初步理解了吧。。。

这样单点修改和区间查询就都很容易解决了,思想都很简单就不多说了。单点修改就是找到线段树中对应位置的叶子节点更新,再逐步向上维护父节点;区间查询最为经典,理解完思想后直接一个循环解决问题。区间和,最值,连续和最值什么的都不在话下了吧。不明白的照着上面的图结合PPT手动模拟一下也都理解了吧。。。

激动地刷一些水题,发现常数果然比递归版高大上许多。。

/*****************我是分割线***************************************/

当我看到“永久化标记”的时候,顿时不明觉厉。。然后就不想继续看了|T T

然后我就熬夜ORZ神犇,写出了一种把递归版写成非递归的暴力方法。。

事实上问题在于,如果遇到需要区间延迟标记的题目,用非递归的写法在处理询问时,(如果像处理询问一样依次给修改区间打标记),上层节点的标记并没有传到当前节点,从而导致答案错误。

这是由于上层节点的标记是比下层节点标记的优先级高的,由于递归版自顶至下,而非递归版自下至上,所以下层节点没有传到,就呵呵了。

但是修改的时候显然也不能暴力的给区间内包含的所有节点打标记,这样就真的变成暴力了。。。

于是我们考虑,也许可以令部分节点标记下传,从而是当前询问区间的答案正确。修改哪些呢,YY了一下,假如要对一段区间进行操作(查询或是修改),先用循环找到左侧和右侧第一个应该被算入答案(或应该被修改的节点),就把这个节点至根的路径上所有的点的标记自上而下下传,因此,实现中用到了栈。此外,如果修改的话,循环结束后,还应该对于两侧的上述节点至根的路径上进行自下而上维护信息。这样每次维护都能保证,而且答案也比较科学。

就以一道简单的题目为例子说明吧。。。POJ3468支持区间加一个数,区间求和。详情见注释。

代码:

#include <cstdio>
#include <cstring>
#include <cctype>
#define N ((131072 << 1) + 10) //表示节点个数->不小于区间长度+2的最小2的正整数次幂*2+10
typedef long long LL;
inline int getc() {
    static const int L = 1 << 15;
    static char buf[L] , *S = buf , *T = buf;
    if (S == T) {
        T = (S = buf) + fread(buf , 1 , L , stdin);
        if (S == T)
            return EOF;
    }
    return *S++;
}
inline int getint() {
    static char c;
    while(!isdigit(c = getc()) && c != '-');
    bool sign = (c == '-');
    int tmp = sign ? 0 : c - '0';
    while(isdigit(c = getc()))
        tmp = (tmp << 1) + (tmp << 3) + c - '0';
    return sign ? -tmp : tmp;
}
inline char getch() {
    char c;
    while((c = getc()) != 'Q' && c != 'C');
    return c;
}
int M; //底层的节点数
int dl[N] , dr[N]; //节点的左右端点
LL sum[N]; //节点的区间和
LL add[N]; //节点的区间加上一个数的标记
#define l(x) (x<<1) //x的左儿子,利用堆的性质
#define r(x) ((x<<1)|1) //x的右儿子,利用堆的性质
void pushdown(int x) { //下传标记
 if (add[x]&&x<M) {//如果是叶子节点,显然不用下传标记(别忘了)
     add[l(x)] += add[x];
        sum[l(x)] += add[x] * (dr[l(x)] - dl[l(x)] + 1);
        add[r(x)] += add[x];
        sum[r(x)] += add[x] * (dr[r(x)] - dl[r(x)] + 1);
        add[x] = 0; 
    }
}
int stack[20] , top;//栈
void upd(int x) { //下传x至根节点路径上节点的标记(自上而下,用栈实现)
 top = 0;
    int tmp = x;
    for(; tmp ; tmp >>= 1)
        stack[++top] = tmp;
    while(top--)
        pushdown(stack[top]);
}
LL query(int tl , int tr) { //求和
 LL res=0;
    int insl = 0, insr = 0; //两侧第一个有用节点
 for(tl=tl+M-1,tr=tr+M+1;tl^tr^1;tl>>=1,tr>>=1) {
        if (~tl&1) {
            if (!insl)
        upd(insl=tl^1);
            res+=sum[tl^1];
        }
        if (tr&1) {
            if(!insr)
        upd(insr=tl^1)
            res+=sum[tr^1];
        }
    }
    return res;
}
void modify(int tl , int tr , int val) { //修改
 int insl = 0, insr = 0;
    for(tl=tl+M-1,tr=tr+M+1;tl^tr^1;tl>>=1,tr>>=1) {
        if (~tl&1) {
            if (!insl)
                upd(insl=tl^1);
            add[tl^1]+=val;
            sum[tl^1]+=(LL)val*(dr[tl^1]-dl[tl^1]+1);
        }
        if (tr&1) {
            if (!insr)
                upd(insr=tr^1);
            add[tr^1]+=val;
            sum[tr^1]+=(LL)val*(dr[tr^1]-dl[tr^1]+1);
        }
    }
    for(insl=insl>>1;insl;insl>>=1) //一路update
     sum[insl]=sum[l(insl)]+sum[r(insl)];
    for(insr=insr>>1;insr;insr>>=1)
        sum[insr]=sum[l(insr)]+sum[r(insr)];
        
        
}
inline void swap(int &a , int &b) {
    int tmp = a;
    a = b;
    b = tmp;
}
int main() {
    //freopen("tt.in" , "r" , stdin);
 int n , ask;
    n = getint();
    ask = getint();
    int i;
    for(M = 1 ; M < (n + 2) ; M <<= 1);
    for(i = 1 ; i <= n ; ++i)
        sum[M + i] = getint() , dl[M + i] = dr[M + i] = i; //建树
 for(i = M - 1; i >= 1 ; --i) { //预处理节点左右端点
     sum[i] = sum[l(i)] + sum[r(i)];
        dl[i] = dl[l(i)];
        dr[i] = dr[r(i)];
    }
    char s;
    int a , b , x;
    while(ask--) {
        s = getch();
        if (s == 'Q') {
            a = getint();
            b = getint();
            if (a > b)
                swap(a , b);
            printf("%lld\n" , query(a , b));
        }
        else {
            a = getint();
            b = getint();
            x = getint();
            if (a > b)
                swap(a , b);
            modify(a , b , x);
        }
    }
    return 0;
}
comments powered by Disqus