学习笔记 - 虚树

初学的时候整个人是懵的
不过总算是弄懂了


『算法概述』

对于一类树上的问题,如果仅有一部分点对答案“有用”(也就是说另一部分“可以不要”),那么我们考虑只存储那些有用的点。这就是虚树的思想。所以为什么它叫虚树?我也不知道……
通常情况下我们把一些点的 LCA 也算作“有用”的点。
虚树可做的题目一般来说(以我对虚树短浅的认知)会限制有用的点的总个数,并且会用多组询问的方式来使一般的方法TLE……


『具体实现』

仿照着其他博客的思路,我觉得用一道例题来讲会更加容易理解——

〔BZOJ 2286 消耗战〕
「题意」
给定一棵n个点的树,每个边有权值表示将它割断的花费,已知其中的k个点是有价值的,现在需要把这些点与点1断开(不连通),求最小花费。
输入时先给定一棵树。
m组数据,每次询问给出k个点,表示它们是有价值的,对于每组数据输出最小花费。
[反正$O(nm)$会超时就对了,保证所有数据的k之和不超过500000]

这道题最基础的做法是对于每一组数据跑一遍DP,显然是 $O(nm)$ 的做法,会 TLE。
那么这道题也算是一道比较特殊的虚树题(某taotao给我说的:这道题必须把根节点1放在虚树里,就不能体现虚树的一些性质)。

「DP部分」

不难想到将点u与根节点割开无非就是割去根节点到u的路径上的一条边,根据这一点我们定义 dp[u] 表示 根节点到u的路径上的最小边,也就是将 u 从根节点割开的最小花费。
那么就可以得到简单的转移式,首先初始值是 $dp[1]=INF$(感性理解就是1不可能与它本身割开),然后转移式:

也就是说要么在 1 到 u 的路径上割掉一条边,要么割断 u 到 v 的边。

「虚树部分」

为什么说“一些点的lca有用”呢?等会在「求解部分」会解释。
虚树的构建方法大概是:

  1. 将有用点按 dfs 序排序;
  2. 创建栈并将第一个点压入;
  3. 枚举下一个点 u (直到枚举完为止);
  4. 如果栈内只有一个点,将 u 压入栈,跳到步骤 3,否则进行下一步;
  5. 求 u 与 栈顶点 的lca;
  6. 如果lca就是栈顶点,跳转到步骤3,否则进行下一步;
  7. 设栈内栈顶点的后面一个点为w;
  8. 如果w的dfs序大于等于lca的dfs序,进行下一步,否则跳转到步骤 10;
  9. 在 w 和 栈顶点 之间连边,并弹出栈顶,跳转到步骤7;
  10. 如果 lca 不是现在的栈顶点,在 lca 和现在栈顶点之间连边;
  11. 弹出栈顶,并将 lca 压入栈;
  12. 压入u,跳转到步骤3;

希望reader们都看懂了,可以把样例拿来自己推一推
这样一棵虚树就构造好了。
注意重置虚树的方法,不要 memset(我已经试过了)

「求解部分」

这里的DFS实际上就是树形DP,而dp[u]表示的并不是DP数组,只是u到根节点到路径上的最小边-既然有reader问到了我还是补一下)
从根节点1开始,在虚树上DFS一遍,如果当前节点是叶子节点,就直接返回它的dp值,否则返回 min{它的dp值,它的所有儿子的返回值之和}
就相当于如果要割断 u 和 v ,要么割断它们的lca,要么分别将它们割断(花费加起来)。所以lca是有用的~

具体一点的话就是:
先判断当前u是否是叶子节点,如果是则说明这一定是题目给出的“有价值”的点,就不得不将它与它的父亲割开——返回 dp[u]
否则判断两种情况: ① 当前点到根节点的路径已经被割断了——也就是 dp[u];② 逐个考虑u的儿子,分别将它们与根节点割开——$\sum_{v\in son of u}DFS(v)$(虽然可能dp[v]表示的割去的边是根节点到u到路径上到边,看似是重复的,但是这样的情况会在①②两种情况取较小值时被排除~)

「至于这道题虚树的用处」

根据DP时的决策,我们发现需要的点无非3种——①根节点,因为这是DP的起点(而且转移时根节点也可能有一定贡献);②“有价值的点”,这会作为虚树的叶子节点,并且限制DP时的转移;③lca,在DP转移时会有两种情况,要么是把lca到根节点的路径割断,要么是把“有价值的点”到lca到路径割断。
那么我们只需要考虑这3种点就可以了。

但是其实这只是一种特例——有一些(大多数)题是不一定要把根节点放在虚树里面的,比如 Codeforces 613D 。

「一些坑」

总的来说虚树题都会出很多个询问,然后我们要对每个询问都建立虚树……于是我们就面临一个问题——怎么重置虚树?
在这里我是用的手写链表储存的邻接表,所以我会在DFS(u)计算出答案之后将u的表头清零(就相当于把与u相连的所有边删掉了),最后再把链表的计数器清零。
切忌用memset(TLE的亲身经历),但是也要注意是否完全清零~


『源代码』

结合代码更容易理解~

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
/*Lucky_Glass*/ 
#include<bits/stdc++.h>
using namespace std;
const int N=250000;
typedef long long ll;
int QRead(){
int a=0,b=1;char c=getchar();
while(!('0'<=c && c<='9')) b=(c=='-'? -1:b),c=getchar();
while('0'<=c && c<='9') a=(a<<3)+(a<<1)+c-'0',c=getchar();
return a*b;
}
struct GRAPH{
struct EDGE{
int to,nxt,cst;
EDGE(){}
EDGE(int _to,int _nxt,int _cst):to(_to),nxt(_nxt),cst(_cst){}
}edg[N*2+7];
int adj[N+7],edgtot;
void ReBuild(){
memset(adj,-1,sizeof adj);
edgtot=0;
}
void AddEdge(int u,int v,int cst,bool dir=false){
edg[++edgtot]=EDGE(v,adj[u],cst);adj[u]=edgtot;
if(!dir) edg[++edgtot]=EDGE(u,adj[v],cst),adj[v]=edgtot;
}
}grp,tre;
int dfncnt,n,q,m,statop;
int dfn[N+7],fa[N+7][20],dep[N+7],pnt[N+7],sta[N+7];
ll dp[N+7];
void DFS(int u,int pre){
fa[u][0]=pre;dep[u]=dep[pre]+1;dfn[u]=++dfncnt;
for(int i=1;i<20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=grp.adj[u];i!=-1;i=grp.edg[i].nxt){
int v=grp.edg[i].to;
if(v==pre) continue;
dp[v]=min(dp[u],1ll*grp.edg[i].cst);
DFS(v,u);
}
}
int LCA(int u,int v){
if(dep[u]>dep[v]) swap(u,v);
for(int i=19;i>=0;i--)
if(dep[fa[v][i]]>=dep[u])
v=fa[v][i];
if(u==v) return u;
for(int i=19;i>=0;i--)
if(fa[v][i]!=fa[u][i])
v=fa[v][i],
u=fa[u][i];
return fa[u][0];
}
void Insert(int u){
if(statop==1){sta[++statop]=u;return;}
int lca=LCA(u,sta[statop]);
if(lca==sta[statop]) return;
while(statop>1 && dfn[sta[statop-1]]>=dfn[lca])
tre.AddEdge(sta[statop-1],sta[statop],0,true),
sta[statop--]=0;
if(lca!=sta[statop]) tre.AddEdge(lca,sta[statop],0,true),sta[statop]=lca;
sta[++statop]=u;
}
ll DP(int u){
if(tre.adj[u]==-1) return dp[u];
ll sum=0;
for(int i=tre.adj[u];i!=-1;i=tre.edg[i].nxt)
sum+=DP(tre.edg[i].to);
tre.adj[u]=-1;
return min(sum,dp[u]);
}
bool cmp(int a,int b){return dfn[a]<dfn[b];}
int main(){
grp.ReBuild();
tre.ReBuild();
n=QRead();
for(int i=1,u,v,cst;i<n;i++)
u=QRead(),v=QRead(),cst=QRead(),
grp.AddEdge(u,v,cst);
dp[1]=(1ll<<60);
DFS(1,0);
q=QRead();
while(q--){
m=QRead();
for(int i=0;i<m;i++) pnt[i]=QRead();
sort(pnt,pnt+m,cmp);
tre.edgtot=0;
sta[statop=1]=1;
for(int i=0;i<m;i++) Insert(pnt[i]);
while(statop>1)
tre.AddEdge(sta[statop-1],sta[statop],0,true),
sta[statop--]=0;
printf("%lld\n",DP(1));
}
return 0;
}


The End

Thanks for reading!

Email: lucky_glass@foxmail.com ,欢迎提问~

0%