codevs 1228 苹果树 树链剖分讲解

  • 时间:
  • 出处:跟我学网络
  • 作者:
  • 浏览:737

标签:树链剖分   bsp   其他   problem   查询   stream   png   const   ble   

 

题目:codevs 1228 苹果树

链接:http://codevs.cn/problem/1228/

 

看了这么多树链剖分的解释,几个小时后总算把树链剖分弄懂了。

树链剖分的功能:快速修改,查询树上的路径。

 

比如一颗树

 

首先,我们要把树剖分成树链。定义:

fa[x]是x节点的上一层节点(就是他的爸爸)。

deep[x]是x节点的深度。

num[x]是x节点下面的子节点的数量(包括自己)

son[x]重儿子:一个节点的儿子的num[x]值最大的节点。其他的儿子都是轻儿子。

重链:重儿子连接在一起的路径,比如上图粗线就是重链(叶节点也是重链,只不过它只有一个点)。

     重链之间是用一条轻链边连接的。

top[x]是每条重链的根节点,即是上图中的红色点。

tree[x]是数上节点在线段树上的编号

ftree[x]是线段树上节点在原来树的节点号

现在把它放到线段树里,从根节点开始编号为1,沿着重链走,每走到一个节点给它编号(可以用一个topa变量记录下一个编号),重链走完了走轻链。如图所示就给每条边都编上号了。如果边的长度没有,当然也可以把节点放在线段树上。图总的蓝色数字就是这条边在线段树里的位置,形成了区间,如下图。

 

然后把这个数组组成最终的线段树,就可以控制它的区间了。

可以发现,虽然看上去把树剖分放到线段树上好像打乱了树的顺序,线段树中的点仍然有原来树的影子。比如如果我要访问x节点的子树,那么这个节点的子树的区间就是从tree[x]到tree[x]+num[x]-1(-1是减掉自己这个节点)的区间。

 

我们可以用2个dfs来把剖分的动作实现。

第一个dfs先实现fa[x],deep[x],num[x]的计算,num要在访问完子树之后计算,见代码:

 1 void dfs1(int x)
 2 {
 3     num[x]++;
 4     for(int i=0;i<map[x].size();i++)
 5     {
 6         int dd=map[x][i];
 7         if(dd!=fa[x])
 8         {
 9             fa[dd]=x;
10             deep[dd]=deep[x]+1;
11             dfs1(dd);
12             num[x]+=num[dd];
13         } 
14     }
15 }

 

注释:map是STL的vector,用来储存边。

第二个dfs完成son[x],tree[x],ftree[x]的计算,代码如下:

 1 void dfs2(int x)
 2 {
 3     topa++;
 4     ftree[topa]=x;
 5     A[topa]++;
 6     tree[x]=topa;
 7     int zi=0,mx=0;
 8     for(int i=0;i<map[x].size();i++)
 9     {
10         int dd=map[x][i];
11         if(num[dd]>mx)
12         {
13             mx=num[dd];
14             zi=dd;
15         }
16     }
17     if(zi!=0) dfs2(zi); else return;
18     son[x]=zi;
19     for(int i=0;i<map[x].size();i++)
20     {
21         int dd=map[x][i];
22         if(dd!=zi) dfs2(dd);
23     }
24 }

剖分动作结束,接下来是线段树的事情了。

这里再说一下如何在线段树上操作原树,之前提到过,其实在线段树上也有原来树的结构。

x的子树区间就是tree[x]到tree[x]+num[x]-1。

 

下面来看一下这道题:codevs 1228 苹果树

这是一个最基本的树链剖分。题目中要求计算一颗子树上有苹果多少颗,改变是点修改。因此只要找到那个节点,子树在线段树上的位置,线段树是维护某区间的苹果树数量,查询操作就是一般的线段树查询。

 

代码:

  1 #include<cstdio>
  2 #include<vector>
  3 #include<iostream>
  4 using namespace std;
  5 const int maxn=100010;
  6 
  7 vector<int> map[maxn];
  8 int fa[maxn],n,deep[maxn],num[maxn],topa,A[maxn],tree[maxn],ftree[maxn],son[maxn],sumv[maxn*4],k;
  9 
 10 void dfs1(int x)
 11 {
 12     num[x]++;
 13     for(int i=0;i<map[x].size();i++)
 14     {
 15         int dd=map[x][i];
 16         if(dd!=fa[x])
 17         {
 18             fa[dd]=x;
 19             deep[dd]=deep[x]+1;
 20             dfs1(dd);
 21             num[x]+=num[dd];
 22         } 
 23     }
 24 }
 25 
 26 void dfs2(int x)
 27 {
 28     topa++;
 29     ftree[topa]=x;
 30     A[topa]++;
 31     tree[x]=topa;
 32     int zi=0,mx=0;
 33     for(int i=0;i<map[x].size();i++)
 34     {
 35         int dd=map[x][i];
 36         if(num[dd]>mx)
 37         {
 38             mx=num[dd];
 39             zi=dd;
 40         }
 41     }
 42     if(zi!=0) dfs2(zi); else return;
 43     son[x]=zi;
 44     for(int i=0;i<map[x].size();i++)
 45     {
 46         int dd=map[x][i];
 47         if(dd!=zi) dfs2(dd);
 48     }
 49 }
 50 
 51 void init(int o,int L,int R)
 52 {
 53     if(L==R) sumv[o]=A[L];
 54     else
 55     {
 56         int M=(L+R)/2;
 57         init(o*2,L,M);
 58         init(o*2+1,M+1,R);
 59         sumv[o]=sumv[o*2]+sumv[o*2+1];
 60     }
 61 }
 62 
 63 int y1,y2,p;
 64 void update(int o,int L,int R)
 65 {
 66     if(L==R) sumv[o]=(sumv[o]+1)%2;
 67     else
 68     {
 69         int M=(L+R)/2;
 70         if(p<=M) update(o*2,L,M);
 71         else update(o*2+1,M+1,R);
 72         sumv[o]=sumv[o*2]+sumv[o*2+1];
 73     }
 74 }
 75 
 76 int ans;
 77 void query(int o,int L,int R)
 78 {
 79     if(y1<=L && R<=y2) ans+=sumv[o];
 80     else
 81     {
 82         int M=(L+R)/2;
 83         if(y1<=M) query(o*2,L,M);
 84         if(y2>M) query(o*2+1,M+1,R);
 85     }
 86 }
 87 
 88 int main()
 89 {
 90     cin>>n;
 91     for(int i=1,x,y;i<=n-1;i++)
 92     {
 93         cin>>x>>y;
 94         map[x].push_back(y);
 95     }
 96     deep[1]=1;
 97     dfs1(1);
 98     dfs2(1);
 99     
100     init(1,1,n);
101     
102     cin>>k;
103     for(int i=1,x;i<=k;i++)
104     {
105         char tp;
106         cin>>tp;
107         if(tp==C)
108         {
109             cin>>x;
110             p=tree[x];
111             update(1,1,n);
112         }
113         else
114         {
115             cin>>x;
116             y1=tree[x];
117             y2=y1+num[x]-1;
118             ans=0;
119             query(1,1,n);
120             cout<<ans<<endl;
121         }
122     }
123     return 0;
124 }