[总结][树链剖分]BZOJ1036,2243,3531

被大JLOI2014虐成翔,人民大众纷纷水过的树链剖分只有暴力50.。。回来后发奋图强搞掉树链剖分。。。

其实弄明白之后树链剖分很简单。

上述信息均可用一次dfs处理。代码如下:

void dfs_build(int x , int fa) //传入当前节点和当前节点的父亲
{
    size[x] = 1; //记录当前节点的size 值
    int nowmax = -1; //当前儿子的最大size值,初始化为-1
    for(register int j = head[x] ; j ; j = next[j]) //链式前向星建图,存储双向边
    {
        if (end[j] != fa) //防止指回自己的父亲
        {
            pa[end[j]] = x;
            depth[end[j]] = depth[x] + 1;
            dfs_build(end[j] , x);
            size[x] += size[end[j]];
            if (size[end[j]] > nowmax) //更新size值最大的儿子
            {
                nowmax = size[end[j]];
                weighson[x] = end[j];
            }
        }
    }
}

主函数调用时,调用dfs(choose_root,-1).choose_root表示你选择的有根树的根,第二个参数可以随意设,但要保证不在
中。

设树的总节点数目为N,则依据下面的规则把N-1条边划分为轻边和重边。重边,即父亲连接自己的重儿子的边。不满足这样条件的边称为轻边。

定义重链为只由重边连续组成的一条路径,且不可再向两侧拓展。

对于一个与父节点连接的边是轻边的叶子节点,也可以说这个叶子节点单独构成了一条退化了的重链(只有一个节点)。

这样,显然每个点均属于一条重链且仅属于一条重链,重链与重链之间以一条轻边连接。(自己画图或到网上找。)

定义为点x所在重链上的depth值最小的点的标号。

如果我们把所有重链取下来连续地放在一条直线上,给每条重链一段连续的区间,给一条重链上的点连续的位置标号,则对于一条路径,不难发现一定可以把路径拆分为若干条重链和若干重链的部分,而每一条重链相当于一段连续的区间,于是可以轻松使用线段树维护查询每一条重链,进而处理路径上的问题。

定义表示树上标号为x的点在线段树中被分配的标号。

定义表示线段树中被分配标号为x的位置对应的树中的节点的标号。

进行第二次dfs维护上述三个信息。代码如下:

void dfs_create(int x , int Top) //参数分别为当前节点,当前节点所在重链上depth最小的点
{
    top[x] = Top; //维护当前节点top值
    p_id[x] = ++id;  //给当前节点一个标号
    id_p[id] = x;
    if (weighson[x])  //优先沿着重儿子匹配标号以使得一条重链在线段树上有着连续的区间
        dfs_create(weighson[x] , Top);
    for(register int j = head[x] ; j ; j = next[j])
        if (end[j] != pa[x] && end[j] != weighson[x]) //如果指向节点是轻儿子
            dfs_create(end[j] , end[j]);  //以轻儿子为depth值最小的点创建一条新的重链
}

最后只需要依据询问,在线段树中修改查询即可。详情见下面的题目。

【BZOJ1036】
不能再裸的裸题。
想要学习的人看一看代码吧。

#include<cstdio>
#include<cstring>
#define N (30010)
int head[N] , next[N << 1] , end[N << 1];
void addedge(int a , int b)
{
    static int q = 1;
    end[q] = b;
    next[q] = head[a];
    head[a] = q++;
}
int size[N] , Weighson[N] , depth[N],  pa[N];
void dfs_build(int x , int fa)
{
    size[x] = 1;
    int now_max_size = -1;
    for(register int j = head[x] ; j ; j = next[j])
    {
        if (end[j] != fa)
        {
            depth[end[j]] = depth[x] + 1;
            pa[end[j]] = x;
            dfs_build(end[j] , x);
            size[x] += size[end[j]];
            if (size[end[j]] > now_max_size)
            {
                now_max_size = size[end[j]];
                Weighson[x] = end[j];
            }
        }
    }
}
int top[N] , id , point_id[N] , id_point[N];
void dfs_create(int x , int Top)
{
    top[x] = Top;
    point_id[x] = ++id;
    id_point[id] = x;
    if (Weighson[x])
        dfs_create(Weighson[x] , Top);
    for(register int j = head[x] ; j ; j = next[j])
        if (end[j] != Weighson[x] && end[j] != pa[x])
            dfs_create(end[j] , end[j]);
}
inline int _max(int a , int b)
{
    return a > b ? a : b;
}
int save[N];
int ind , dl[N << 2] , dr[N << 2] , l[N << 2] , r[N << 2] , Max[N << 2] , Sum[N << 2];
int build(int tl , int tr)
{
    int q = ++ind;
    dl[q] = tl;
    dr[q] = tr;
    if (tl == tr)
    {
       l[q] = r[q] = 0;
       Max[q] = Sum[q] = save[id_point[tl]];
       return q;
    }
    int mid = (tl + tr) >> 1;
    l[q] = build(tl , mid);
    r[q] = build(mid + 1 , tr);
    Max[q] = _max(Max[l[q]] , Max[r[q]]);
    Sum[q] = Sum[l[q]] + Sum[r[q]];
    return q;
}
int query_ADD(int tl , int tr , int q)
{
    if (tl <= dl[q] && dr[q] <= tr)
        return Sum[q];
    int mid = (dl[q] + dr[q]) >> 1;
    if (tl > mid)
        return query_ADD(tl , tr , r[q]);
    else if (tr <= mid)
        return query_ADD(tl , tr , l[q]);
    else
        return query_ADD(tl , mid , l[q]) + query_ADD(mid + 1 , tr , r[q]);
}
int query_MAX(int tl , int tr , int q)
{
    if (tl <= dl[q] && dr[q] <= tr)
        return Max[q];
    int mid = (dl[q] + dr[q]) >> 1;
    if (tl > mid)
        return query_MAX(tl , tr , r[q]);
    else if (tr <= mid)
        return query_MAX(tl , tr , l[q]);
    else
        return _max(query_MAX(tl , mid , l[q]) , query_MAX(mid + 1 , tr , r[q]));
}
void modify(int ins , int to , int q)
{
    if (dl[q] == dr[q])
    {
        Sum[q] = Max[q] = to;
        return;
    }
    int mid = (dl[q] + dr[q]) >> 1;
    if (ins <= mid)
        modify(ins , to , l[q]);
    else
        modify(ins , to , r[q]);
    Max[q] = _max(Max[l[q]] , Max[r[q]]);
    Sum[q] = Sum[l[q]] + Sum[r[q]];
}
inline void swap(int &a , int &b)
{
    int tmp = a;
    a = b;
    b = tmp;
}
int ask_max(int a , int b)
{
    int ret = -1 << 30;
    while (top[a] != top[b])
    {
        if (depth[top[a]] < depth[top[b]])
            swap(a,  b);
        ret = _max(ret , query_MAX(point_id[top[a]] , point_id[a] , 1));
        a = pa[top[a]];
    }
    if (depth[a] < depth[b])
        swap(a , b);
    ret = _max(ret , query_MAX(point_id[b] , point_id[a] , 1));
    return ret;
}
int ask_sum(int a , int b)
{
    int ret = 0;
    while(top[a] != top[b])
    {
        if (depth[top[a]] < depth[top[b]])
            swap(a , b);
        ret += query_ADD(point_id[top[a]] , point_id[a] , 1);
        a = pa[top[a]];
    }
    if (depth[a] < depth[b])
        swap(a , b);
    ret += query_ADD(point_id[b] , point_id[a] , 1);
    return ret;
}
int main()
{
    //freopen("tt.in" , "r" , stdin);
    //freopen("tt.out" , "w" , stdout);
    int n;
    scanf("%d" , &n);
    int a , b;
    register int i;
    for(i = 1 ; i < n ; ++i)
    {
        scanf("%d%d" , &a , &b);
        addedge(a , b);
        addedge(b , a);
    }
    dfs_build(1 , -1);
    dfs_create(1 , 1);
    for(i = 1 ; i <= n ; ++i)
        scanf("%d" , &save[i]);
    build(1 , id);
    int ask;
    scanf("%d" , &ask);
    char ch[10];
    while(ask--)
    {
        scanf("%s",  ch);
        if (ch[0] == 'Q' && ch[1] == 'M')
        {
            scanf("%d%d" , &a , &b);
            printf("%d\n" , ask_max(a , b));
        }
        if (ch[0] == 'Q' && ch[1] == 'S')
        {
            scanf("%d%d" , &a , &b);
            printf("%d\n" , ask_sum(a , b));
        }
        if (ch[0] == 'C')
        {
            scanf("%d%d" , &a , &b);
            modify(point_id[a] , b , 1);
        }
    }
    return 0;
}

【BZOJ2243】
这个题不是那么裸了,弱渣看了题解才想出来。。。
先树链剖分,然后对于每段区间维护区间左端颜色,区间右端颜色,区间内颜色段数。很容易维护。
至于最后询问路径的时候,如果发现当前重链的top位置的颜色与他的父亲的颜色相同,则答案减1,因为没有新的颜色段出现。
代码:

#include<cstdio>
#include<cstring>
#include<cctype>
#define N (100010)
#define INF (1 << 30)
inline char getc()
{
    static const int L = 1 << 15;
    static char buf[L] , *S = buf , *T = buf;
    if (S == T)
    {
        T = (S = buf) + fread(buf , 1 , L , stdin);
        if (S == T)
            return EOF;
    }
    return *S++;
}
inline int getint()
{
    static char c;
    while(!isdigit(c = getc()));
    int tmp = c - '0';
    while(isdigit(c = getc()))
        tmp = (tmp << 1) + (tmp << 3) + c - '0';
    return tmp;
}
inline char getch()
{
    static char c;
    while((c = getc()) != 'C' && c != 'Q');
    return c;
}
int head[N] , next[N << 1] , end[N << 1];
void addedge(int a , int b)
{
    static int q = 1;
    end[q] = b;
    next[q] = head[a];
    head[a] = q++;
}
int depth[N] , pa[N] , size[N] , Weighson[N];
void dfs_build(int x , int fa)
{
    size[x] = 1;
    int Maxsize = -1;
    for(register int j = head[x] ; j ; j = next[j])
    {
        if (end[j] != fa)
        {
            depth[end[j]] = depth[x] + 1;
            pa[end[j]] = x;
            dfs_build(end[j] , x);
            size[x] += size[end[j]];
            if (size[end[j]] > Maxsize)
            {
                Maxsize = size[end[j]];
                Weighson[x] = end[j];
            }
        }
    }
}
int top[N] , p_id[N] , id_p[N] , id;
void dfs_create(int x , int Top)
{
    top[x] = Top;
    p_id[x] = ++id;
    id_p[id] = x;
    if (Weighson[x])
        dfs_create(Weighson[x] , Top);
    for(register int j = head[x] ; j ; j = next[j])
        if (end[j] != pa[x] && end[j] != Weighson[x])
            dfs_create(end[j] , end[j]);
}
int col[N];
int dl[N << 2] , dr[N << 2] , l[N << 2] , r[N << 2] , left[N << 2] , right[N << 2] , same[N << 2] , num[N << 2] , ind;
void update(int x)
{
    num[x] = num[l[x]] + num[r[x]] - (left[r[x]] == right[l[x]]);
    left[x] = left[l[x]];
    right[x] = right[r[x]];
}
void pushsame(int x , int add)
{
    left[x] = right[x] = same[x] = add;
    num[x] = 1;
}
void pushdown(int x)
{
    if (same[x] != -INF)
    {
        if (l[x])
            pushsame(l[x] , same[x]);
        if (r[x])
            pushsame(r[x] , same[x]);
        same[x] = -INF;
    }
}
int build(int tl , int tr)
{
    int q = ++ind;
    same[q] = -INF;
    dl[q] = tl;
    dr[q] = tr;
    if (tl == tr)
    {
        left[q] = right[q] = col[id_p[tl]];
        num[q] = 1;
        return q;
    }
    int mid = (tl + tr) >> 1;
    l[q] = build(tl , mid);
    r[q] = build(mid + 1 , tr);
    update(q);
    return q;
}
void modify(int tl , int tr , int add , int q)
{
    pushdown(q);
    if (tl <= dl[q] && dr[q] <= tr)
    {
        pushsame(q , add);
        return;
    }
    int mid = (dl[q] + dr[q]) >> 1;
    if (tl > mid)
        modify(tl , tr , add , r[q]);
    else if (tr <= mid)
        modify(tl , tr , add , l[q]);
    else
        modify(tl , mid , add , l[q]) , modify(mid + 1 , tr , add , r[q]);
    update(q);
    return;
}
int ask_col(int ins , int q)
{
    pushdown(q);
    if (dl[q] == dr[q])
        return left[q];
    int mid = (dl[q] + dr[q]) >> 1;
    if (ins <= mid)
        return ask_col(ins , l[q]);
    else
        return ask_col(ins , r[q]);
}
int ask_num(int tl , int tr , int q)
{
    pushdown(q);
    if (tl <= dl[q] && dr[q] <= tr)
        return num[q];
    int mid = (dl[q] + dr[q]) >> 1;
    if (tl > mid)
        return ask_num(tl , tr , r[q]);
    else if (tr <= mid)
        return ask_num(tl , tr , l[q]);
    else
        return ask_num(tl , mid , l[q]) + ask_num(mid + 1 , tr , r[q]) - (ask_col(mid , 1) == ask_col(mid + 1 , 1));
}
inline void swap(int &a , int &b)
{
    int tmp = a;
    a = b;
    b = tmp;
}
void call_modify(int a , int b , int to)
{
    while(top[a] != top[b])
    {
        if (depth[top[a]] < depth[top[b]])
            swap(a , b);
        modify(p_id[top[a]] , p_id[a] , to , 1);
        a = pa[top[a]];
    }
    if (depth[a] < depth[b])
        swap(a , b);
    modify(p_id[b] , p_id[a] , to , 1);
}
int getans(int a , int b)
{
    int ret = 0;
    while(top[a] != top[b])
    {
        if (depth[top[a]] < depth[top[b]])
            swap(a , b);
        ret += ask_num(p_id[top[a]] , p_id[a] , 1);
        //now = ask_col(p_id[a] , 1);
        if (ask_col(p_id[top[a]] , 1) == ask_col(p_id[pa[top[a]]] , 1))
            --ret;
        //last[cur] = ask_col(p_id[top[a]] , 1);
        a = pa[top[a]];
    }
    if (depth[a] < depth[b])
        swap(a , b);
    ret += ask_num(p_id[b] , p_id[a] , 1);
    return ret;
}
int main()
{
    //freopen("paint7.in" , "r" , stdin);
    //freopen("tt.out" , "w" , stdout);
    int n , m;
    //scanf("%d%d" , &n , &m);
    n = getint() , m = getint();
    register int i;
    for(i = 1 ; i <= n ; ++i)
        //scanf("%d" , &col[i]);
        col[i] = getint();
    int a , b , x;
    for(i = 1 ; i < n ; ++i)
    {
        //scanf("%d%d" , &a , &b);
        a = getint() , b = getint();
        addedge(a , b);
        addedge(b , a);
    }
    dfs_build(1 , -1);
    dfs_create(1 , 1);
    build(1 , id);
    //char s[10];
    char ch;
    while(m--)
    {
        ch = getch();
        a = getint();
        b = getint();
        //scanf("%s%d%d" , s , &a , &b);
        //switch (s[0])
        switch(ch)
        {
            case 'Q':
            {
                printf("%d\n" , getans(a , b));
                break;
            }
            case 'C':
            {
                //scanf("%d" , &x);
                x = getint();
                call_modify(a , b , x);
            }
        }
    }
    return 0;
}

【BZOJ3531】
暴力维护每一种信仰的线段树,每次修改时动态开点。
最终的总点数不超过.
代码:

#include<cstdio>
#include<cstring>
#define  N (100010)
int head[N] , next[N << 1] , end[N << 1];
void addedge(int a , int b)
{
    static int q = 1;
    end[q] = b;
    next[q] = head[a];
    head[a] = q++;
}
int size[N] , weighson[N] , pa[N] , depth[N];
void dfs_build(int x , int fa)
{
    size[x] = 1;
    int nowmax = -1;
    for(register int j = head[x] ; j ; j = next[j])
    {
        if (end[j] != fa)
        {
            pa[end[j]] = x;
            depth[end[j]] = depth[x] + 1;
            dfs_build(end[j] , x);
            size[x] += size[end[j]];
            if (size[end[j]] > nowmax)
            {
                nowmax = size[end[j]];
                weighson[x] = end[j];
            }
        }
    }
}
int top[N] , p_id[N] , id_p[N] , id;
void dfs_create(int x , int Top)
{
    top[x] = Top;
    p_id[x] = ++id;
    id_p[id] = x;
    if (weighson[x])
        dfs_create(weighson[x] , Top);
    for(register int j = head[x] ; j ; j = next[j])
        if (end[j] != pa[x] && end[j] != weighson[x])
            dfs_create(end[j] , end[j]);
}
int save_col[N] , value[N];
struct Node
{
    int dl , dr , l , r , max , add;
}S[4200000];
#define dl(x) S[x].dl
#define dr(x) S[x].dr
#define l(x) S[x].l
#define r(x) S[x].r
#define max(x) S[x].max
#define add(x) S[x].add
inline int _max(int a , int b)
{
    return a > b ? a : b;
}
int ind;
void update(int x)
{
    max(x) = _max(max(l(x)) , max(r(x)));
    add(x) = add(l(x)) + add(r(x));
}
#define Maxcol (100010)
int root[Maxcol];
void modify(int ins , int val , int q)
{
    if (!q)
        return;
    if (dl(q) == dr(q))
    {
        max(q) = add(q) = val;
        return;
    }
    int mid = (dl(q) + dr(q)) >> 1;
    if (ins <= mid)
    {
        if (!l(q))
        {
            l(q) = ++ind;
            dl(ind) = dl(q);
            dr(ind) = mid;
        }
        modify(ins , val , l(q));
    }
    else
    {
        if (!r(q))
        {
            r(q) = ++ind;
            dl(ind) = mid + 1;
            dr(ind) = dr(q);
        }
        modify(ins , val , r(q));
    }
    update(q);
    return;
}
int query_add(int tl , int tr , int q)
{
    if (!q)
        return 0;
    if (tl <= dl(q) && dr(q) <= tr)
        return add(q);
    int mid = (dl(q) + dr(q)) >> 1;
    if (tl > mid)
        return query_add(tl , tr , r(q));
    else if (tr <= mid)
        return query_add(tl , tr , l(q));
    else
        return query_add(tl , mid , l(q)) + query_add(mid + 1 , tr , r(q));
}
int query_max(int tl , int tr , int q)
{
    if (!q)
        return 0;
    if (tl <= dl(q) && dr(q) <= tr)
        return max(q);
    int mid = (dl(q) + dr(q)) >> 1;
    if (tl > mid)
        return query_max(tl , tr , r(q));
    else if (tr <= mid)
        return query_max(tl , tr , l(q));
    else
        return _max(query_max(tl , mid , l(q)) , query_max(mid + 1 , tr , r(q)));
}
inline void swap(int &a , int &b)
{
    int tmp = a;
    a = b;
    b = tmp;
}
int ask_add(int a , int b , int col)
{
    int ret = 0;
    while(top[a] != top[b])
    {
        if (depth[top[a]] < depth[top[b]])
            swap(a , b);
        ret += query_add(p_id[top[a]] , p_id[a] , root[col]);
        a = pa[top[a]];
    }
    if (depth[a] < depth[b])
        swap(a , b);
    ret += query_add(p_id[b] , p_id[a] , root[col]);
    return ret;
}
int ask_max(int a , int b , int col)
{
    int ret = 0;
    while(top[a] != top[b])
    {
        if (depth[top[a]] < depth[top[b]])
            swap(a , b);
        ret = _max(ret , query_max(p_id[top[a]] , p_id[a] , root[col]));
        a = pa[top[a]];
    }
    if (depth[a] < depth[b])
        swap(a , b);
    ret = _max(ret , query_max(p_id[b] , p_id[a] , root[col]));
    return ret;
}
int main()
{
    //freopen("tt.in" , "r" , stdin);
    int n , ask;
    scanf("%d%d" , &n , &ask);
    register int i;
    for(i = 1 ; i <= n ; ++i)
        scanf("%d%d" , &value[i] , &save_col[i]);
    int a , b;
    for(i = 1 ; i < n ; ++i)
        scanf("%d%d" , &a , &b) , addedge(a , b) , addedge(b , a);
    dfs_build(1 , -1);
    dfs_create(1 , 1);
    for(i = 1 ; i <= Maxcol ; ++i)
    {
        root[i] = ++ind;
        dl(ind) = 1;
        dr(ind) = id;
    }
    for(i = 1 ; i <= n ; ++i)
        modify(p_id[i] , value[i] , root[save_col[i]]);
    char s[10];
    while(ask--)
    {
        scanf("%s%d%d" , s , &a , &b);
        if (s[0] == 'C' && s[1] == 'C')
        {
            modify(p_id[a] , 0 , root[save_col[a]]);
            modify(p_id[a] , value[a] , root[b]);
            save_col[a] = b;
        }
        if (s[0] == 'C' && s[1] == 'W')
        {
            modify(p_id[a] , b , root[save_col[a]]);
            value[a] = b;
        }
        if (s[0] == 'Q' && s[1] == 'S')
            printf("%d\n" , ask_add(a , b , save_col[a]));
        if (s[0] == 'Q' && s[1] == 'M')
            printf("%d\n" , ask_max(a , b , save_col[a]));
    }
    return 0;
}

comments powered by Disqus