虚树

ooliver 发布于 2 小时前 11 次阅读 OI


AI 摘要

当面对海量关键点查询时,暴力求解必然超时。虚树的核心思想是:**只保留关键点及其LCA,在O(k log k)内重构一棵精简树**。本文从“大工程”中的最值总和查询,到“消耗战”的树形DP,再到“世界树”的复杂归属判定,以及“寻宝游戏”的动态维护,带你彻底搞懂虚树的构建与四大经典应用。

P4103 [HEOI2014] 大工程

先建虚树,对于最大值和最小值是很好维护的,对于总和,我们用一个标记记录当前节点作为 lca 的贡献次数,递归到子节点后在加上贡献次数。

C++
#include<bits/stdc++.h>
using namespace std;

#define int long long
#define lc p<<1
#define rc p<<1|1

const int N=1e6+5;

int n,m,k,idx;
int ans1,ans2,ans3;
int a[N],node[N],tag[N],sum[N],mn[N],mx[N];
int fa[N],son[N],dep[N],top[N],siz[N],dfn[N];
vector<int> t[N],g[N];

bool cmp(int x,int y){return dfn[x]<dfn[y];}

void dfs1(int u,int f){
    fa[u]=f,siz[u]=1,dep[u]=dep[f]+1;
    for(int v:g[u]){
        if(v==f) continue;
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]]) son[u]=v;
    }
}

void dfs2(int u,int tp){
    top[u]=tp,dfn[u]=++idx;
    if(!son[u]) return;
    dfs2(son[u],tp);
    for(int v:g[u]){
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}

int lca(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]>dep[y]?y:x;
}

void build(){
    int len=0;
    sort(a+1,a+1+k,cmp);
    for(int i=1;i<=k;i++){
        node[++len]=a[i];
        if(i<k){
            int f=lca(a[i],a[i+1]);
            node[++len]=f;
        }
    }
    sort(node+1,node+1+len,cmp);
    len=unique(node+1,node+1+len)-node-1;
    for(int i=1;i<len;i++){
        int f=lca(node[i],node[i+1]);
        t[f].push_back(node[i+1]);
    }
}

void dfs3(int u){
    sum[u]=tag[u];
    for(int v:t[u]){
        dfs3(v);
        sum[u]+=sum[v];
    }
}

void dfs4(int u,int k1){
    if(tag[u]) ans1+=dep[u]*k1,mn[u]=mx[u]=dep[u];
    for(int v:t[u]){
        ans1-=sum[v]*(sum[u]-sum[v])*dep[u];
        dfs4(v,k1+(sum[u]-sum[v]));
        ans2=min(ans2,mn[v]+mn[u]-2*dep[u]),ans3=max(ans3,mx[v]+mx[u]-2*dep[u]);
        mn[u]=min(mn[u],mn[v]),mx[u]=max(mx[u],mx[v]);
        mn[v]=1e18,mx[v]=-1e18;
    }
    t[u].clear();
}

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n;
    for(int i=1;i<=n;i++){
        if(i<n){
            int u,v;
            cin>>u>>v;
            g[u].push_back(v);
            g[v].push_back(u);
        }
        mn[i]=1e18,mx[i]=-1e18;
    }
    dfs1(1,0),dfs2(1,1);
    cin>>m;
    while(m--){
        ans1=0,ans2=1e18,ans3=-1e18;
        cin>>k;
        for(int i=1;i<=k;i++) cin>>a[i],tag[a[i]]=1;
        build();
        dfs3(node[1]),dfs4(node[1],0);
        cout<<ans1<<" "<<ans2<<" "<<ans3<<"\n";
        for(int i=1;i<=k;i++) tag[a[i]]=0;
        mn[node[1]]=1e18,mx[node[1]]=-1e18;
    }
    return 0;
}

P2495 【模板】虚树 / [SDOI2011] 消耗战

这个是虚树模板。
建虚树后直接 dp 即可,分两种情况:

  1. 切断当前节点与根节点路径的最小边
  2. 切断所有子树节点与根节点路径的最小边

代码:

C++
#include<bits/stdc++.h>
using namespace std;

#define int long long
const int N=5e5+5;
int n,m,k,idx;
int a[N],node[N],tag[N];
int mn[N],dp[N];
int fa[N],son[N],dep[N],top[N],siz[N],dfn[N];
struct edge{
    int v,w;
};
vector<edge> g[N];
vector<int> t[N];

bool cmp(int x,int y){return dfn[x]<dfn[y];}

void dfs1(int u,int f){
    fa[u]=f,siz[u]=1,dep[u]=dep[f]+1;
    for(edge x:g[u]){
        int v=x.v,w=x.w;
        if(v==f) continue;
        mn[v]=min(mn[u],w);
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]]) son[u]=v;
    }
}

void dfs2(int u,int tp){
    top[u]=tp,dfn[u]=++idx;
    if(!son[u]) return;
    dfs2(son[u],tp);
    for(edge x:g[u]){
        int v=x.v,w=x.w;
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}

int lca(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]>dep[y]?y:x;
}

void build(){
    int len=0;
    sort(a+1,a+1+k,cmp);
    for(int i=1;i<=k;i++){
        node[++len]=a[i];
        if(i<k){
            int f=lca(a[i],a[i+1]);
            node[++len]=f;
        }
    }
    sort(node+1,node+1+len,cmp);
    len=unique(node+1,node+1+len)-node-1;
    for(int i=1;i<len;i++){
        int f=lca(node[i],node[i+1]);
        t[f].push_back(node[i+1]);
    }
}

void dfs(int u){
    int sum=0;
    for(int v:t[u]){
        dfs(v);
        sum+=dp[v],dp[v]=0;
    }
    if(tag[u]) dp[u]=mn[u];
    else dp[u]=min(mn[u],sum);
    t[u].clear();
}

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v,w;
        cin>>u>>v>>w;
        g[u].push_back({v,w});
        g[v].push_back({u,w});
    }
    mn[1]=LONG_LONG_MAX;
    dfs1(1,0),dfs2(1,1);
    cin>>m;
    while(m--){
        cin>>k;
        for(int i=1;i<=k;i++) cin>>a[i],tag[a[i]]=1;
        build();
        dfs(node[1]);
        cout<<dp[node[1]]<<"\n";
        for(int i=1;i<=k;i++) tag[a[i]]=0;
    }
    return 0;
}

P3233 [HNOI2014] 世界树

先建虚树,现在分两种情况:

  1. 虚树内的节点:做两次 dfs,上下遍历更新两次即可。
  2. 不在虚树上的节点:在第二次 dfs 的边的枚举中单独计算,从子节点倍增跳到父节点,找到边上的临界点即可。

代码:

C++
#include<bits/stdc++.h>
using namespace std;

#define int long long
#define lc p<<1
#define rc p<<1|1

const int N=3e5+5;

int n,m,k,idx;
int a[N],b[N],node[N],tag[N];
int dfn[N],siz[N],fa[N][25],d[N];
int dis[N],p[N],ans[N];
vector<int> t[N],g[N];

bool cmp(int x,int y){return dfn[x]<dfn[y];}

void dfs(int u,int f){
	siz[u]=1,fa[u][0]=f,d[u]=d[f]+1,dfn[u]=++idx;
	for(int i=1;i<=18;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
	for(int v:g[u]){
		if(v==f) continue;
		dfs(v,u);
        siz[u]+=siz[v];
	}
}

int lca(int u,int v){
	if(d[u]<d[v]) swap(u,v);
	for(int i=20;i>=0;i--) if(d[fa[u][i]]>=d[v]) u=fa[u][i];
	if(u==v) return u;
	for(int i=20;i>=0;i--) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
	return fa[u][0];
}

void build(){
    int len=1;
    node[1]=1,sort(a+1,a+1+k,cmp);
    for(int i=1;i<=k;i++){
        node[++len]=a[i];
        if(i<k){
            int f=lca(a[i],a[i+1]);
            node[++len]=f;
        }
    }
    sort(node+1,node+1+len,cmp);
    len=unique(node+1,node+1+len)-node-1;
    for(int i=1;i<len;i++){
        int f=lca(node[i],node[i+1]);
        t[f].push_back(node[i+1]);
    }
}

void get(int u,int v){
    int x=v,y=v;
    for(int i=20;i>=0;i--) if(d[fa[x][i]]>d[u]) x=fa[x][i];
    ans[p[u]]-=siz[x];
    for(int i=20;i>=0;i--){
        int d1=d[v]-d[fa[y][i]]+dis[v],d2=d[fa[y][i]]-d[u]+dis[u];
        if(d[fa[y][i]]>d[u]&&(d1<d2||(d1==d2&&p[v]<p[u]))) y=fa[y][i];
    }
    ans[p[u]]+=siz[x]-siz[y],ans[p[v]]+=siz[y]-siz[v];
}

void dfs1(int u){
    dis[u]=1e18;
    for(int v:t[u]){
        dfs1(v);
        if(d[v]-d[u]+dis[v]<dis[u]) dis[u]=d[v]-d[u]+dis[v],p[u]=p[v];
        else if(d[v]-d[u]+dis[v]==dis[u]) p[u]=min(p[v],p[u]);
    }
    if(tag[u]) dis[u]=0,p[u]=u;
}

void dfs2(int u){
    for(int v:t[u]){
        if(d[v]-d[u]+dis[u]<dis[v]) dis[v]=d[v]-d[u]+dis[u],p[v]=p[u];
        else if(d[v]-d[u]+dis[u]==dis[v]) p[v]=min(p[v],p[u]);
        get(u,v);
        dfs2(v);
    }
    ans[p[u]]+=siz[u];
    tag[u]=0,t[u].clear();
}

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1,0);
    cin>>m;
    while(m--){
        cin>>k;
        for(int i=1;i<=k;i++) cin>>a[i],b[i]=a[i],tag[a[i]]=1;
        build();
        dfs1(1),dfs2(1);
        for(int i=1;i<=k;i++) cout<<ans[b[i]]<<" ",tag[b[i]]=0,ans[b[i]]=0;
        cout<<"\n";
    }
    return 0;
}

P3320 [SDOI2015] 寻宝游戏

有一个结论:虚树的边权和的两倍等于把所有关键点按 dfn 序排序并对相邻两个节点的距离求和(包括第一个节点和最后一个节点)。

那在这里我们直接用 set 维护排序后的关键点的 dfn 序,动态对答案加减即可。

代码:

C++
#include<bits/stdc++.h>
using namespace std;

#define int long long
const int N=5e5+5;
int n,m,k,idx,ans;
int fa[N],son[N],dep[N],top[N],siz[N],dfn[N],inv[N],dis[N],tag[N];
set<int> s;
set<int>::iterator it;
struct edge{int v,w;};
vector<edge> g[N];

void dfs1(int u,int f){
    fa[u]=f,siz[u]=1,dep[u]=dep[f]+1;
    for(edge x:g[u]){
        int v=x.v,w=x.w;
        if(v==f) continue;
        dis[v]=dis[u]+w;
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]]) son[u]=v;
    }
}

void dfs2(int u,int tp){
    top[u]=tp,dfn[u]=++idx,inv[idx]=u;
    if(!son[u]) return;
    dfs2(son[u],tp);
    for(edge x:g[u]){
        int v=x.v,w=x.w;
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}

int lca(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]>dep[y]?y:x;
}

int path(int x,int y){
    return dis[x]+dis[y]-2*dis[lca(x,y)];
}

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<n;i++){
        int u,v,w;
        cin>>u>>v>>w;
        g[u].push_back({v,w});
        g[v].push_back({u,w});
    }
    dfs1(1,0),dfs2(1,1);
    while(m--){
        int x;
        cin>>x;
        if(!tag[x]) s.insert(dfn[x]);
        int l,r;
        it=s.lower_bound(dfn[x]);
        if(it==s.begin()) l=*--s.end();
        else l=*--it;
        it=s.upper_bound(dfn[x]);
        if(it==s.end()) r=*s.begin();
        else r=*it;
        if(tag[x]) s.erase(dfn[x]);
        tag[x]^=1;
        l=inv[l],r=inv[r];
        if(tag[x]) ans+=path(l,x)+path(r,x)-path(l,r);
        else ans-=path(l,x)+path(r,x)-path(l,r);
        cout<<ans<<"\n";
    }
    return 0;
}