线段树分治

用户头像 发布于 9 天前 58 次阅读 OI


线段树分治

线段树分治,顾名思义,就是用线段树进行分治。

P5787 【模板】线段树分治 / 二分图

首先,我们知道二分图的一个充要条件是没有奇数环
所以,这道题中,线段树维护的是每个时间的区间内存在的边。
而这里使用染色法判断二分图显然时间复杂度会超,
所以需要用到可撤销并查集,即用栈维护每次并查集的操作。
根据以上几点得出以下代码:

#include<bits/stdc++.h>
using namespace std;

#define lc (p<<1)
#define rc (p<<1|1)

const int N=4e6+5;
struct node{
    int x,y,hy;
};
stack<node> st;
vector<pair<int,int>> tr[N];
int n,m,k;
int f[N],siz[N],ans[N];

void insert(int p,int l,int r,int x,int y, pair<int,int> e){
    if(x>r || y<l) return;
    if(x<=l && r<=y) return tr[p].push_back(e);
    int mid=l+r>>1;
    insert(lc,l,mid,x,y,e);
    insert(rc,mid+1,r,x,y,e);
}

int find(int x){
    while(x!=f[x]) x=f[x];
    return f[x];
}

void merge(int x,int y){
    x=find(x);
    y=find(y);
    if(siz[x]>siz[y]) swap(x,y);
    st.push({x,y,siz[y]});
    f[x]=y;
    siz[y]+=(siz[x]==siz[y]);
}

void check(int p,int l,int r){
    int flag=0;
    int now=st.size();
    for(int i=0;i<tr[p].size();i++){
        pair<int,int> e=tr[p][i];
        merge(e.first,e.second+n);
        merge(e.second,e.first+n);
        if(find(e.first)==find(e.second)){
            flag=1;
            break;
        }
    }
    if(!flag){
        if(l==r) ans[l]=1;
        else{
            int mid=l+r>>1;
            check(lc,l,mid);
            check(rc,mid+1,r);
        }
    }
    while(st.size()>now){
        node t=st.top();
        st.pop();
        f[t.x]=t.x;
        siz[t.y]=t.hy;
    }
}

int main(){
    cin>>n>>m>>k;
    for(int i=1;i<=2*n;i++) f[i]=i;
    for(int i=1;i<=m;i++){
        int x,y,l,r;
        cin>>x>>y>>l>>r;
        insert(1,1,k,l+1,r,{x,y});
    }
    check(1,1,k);
    for(int i=1;i<=k;i++){
        if(ans[i]) cout<<"Yes\n";
        else cout<<"No\n";
    }
    return 0;
}

P5227 [AHOI2013] 连通图

可以把每一次询问看作一个时间点,
也就是说当前时间点该边不存在,
那么就可以算出改变存在的时间段,
就可以套用线段树合并的模板。
代码:

#include<bits/stdc++.h>
using namespace std;

#define lc (p<<1)
#define rc (p<<1|1)

const int N=4e6+5;
struct node{
    int x,y,hy;
};
stack<node> st;
vector<pair<int,int> > tr[N];
int n,m,k;
int f[N],siz[N],ans[N],bgn[N],u[N],v[N];

void insert(int p,int l,int r,int x,int y, pair<int,int> e){
    if(x>r || y<l) return;
    if(x<=l && r<=y) return tr[p].push_back(e);
    int mid=l+r>>1;
    insert(lc,l,mid,x,y,e);
    insert(rc,mid+1,r,x,y,e);
}

int find(int x){
    while(x!=f[x]) x=f[x];
    return f[x];
}

void merge(int x,int y){
    x=find(x);
    y=find(y);
    if(x==y) return;
    if(siz[x]>siz[y]) swap(x,y);
    st.push({x,y,siz[y]});
    f[x]=y;
    siz[y]+=siz[x];
}

void check(int p,int l,int r){
	int flag=1;
    int now=st.size();
    for(int i=0;i<tr[p].size();i++){
        pair<int,int> e=tr[p][i];
        merge(e.first,e.second);
    }
	if(l==r){
    	if(siz[find(1)]!=n) flag=0;
    	ans[l]=flag;
	}
    else{
        int mid=l+r>>1;
        check(lc,l,mid);
        check(rc,mid+1,r);
    }
    while(st.size()>now){
        node t=st.top();
        st.pop();
        f[t.x]=t.x;
        siz[t.y]=t.hy;
    }
}

int main(){
	ios::sync_with_stdio(0);
	cin.tie(0),cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n;i++) f[i]=i,siz[i]=1;
    for(int i=1;i<=m;i++) cin>>u[i]>>v[i];
    cin>>k;
    for(int i=1;i<=k;i++){
    	int c;
    	cin>>c;
    	while(c--){
    		int x;
    		cin>>x;
    		if(bgn[x]+1<=i-1) insert(1,1,k,bgn[x]+1,i-1,{u[x],v[x]});
    		bgn[x]=i;
		}
	}
	for(int i=1;i<=m;i++) if(bgn[i]<k) insert(1,1,k,bgn[i]+1,k,{u[i],v[i]});
    check(1,1,k);
    for(int i=1;i<=k;i++){
        if(ans[i]) cout<<"Connected\n";
        else cout<<"Disconnected\n";
    }
    return 0;
}

P4219 [BJOI2014] 大融合

不难发现,答案即把询问的边删除后两个顶点所在连通块大小的乘积,
其余部分与上一题一模一样,
代码:

#include<bits/stdc++.h>
using namespace std;

#define lc (p<<1)
#define rc (p<<1|1)

const int N=4e6+5;
struct node{
    int x,y,hy;
};
stack<node> st;
vector<pair<int,int> > tr[N];
map<pair<int,int>,int> mp;
int n,q;
int f[N],siz[N],ans[N],uu[N],vv[N];
char op1[N];
int u1[N],v1[N];

void insert(int p,int l,int r,int x,int y, pair<int,int> e){
    if(x>r||y<l) return;
    if(x<=l&&r<=y) return tr[p].push_back(e);
    int mid=(l+r)>>1;
    insert(lc,l,mid,x,y,e);
    insert(rc,mid+1,r,x,y,e);
}

int find(int x){
    while(x!=f[x]) x=f[x];
    return f[x];
}

void merge(int x,int y){
    x=find(x);
    y=find(y);
    if(x==y) return;
    if(siz[x]>siz[y]) swap(x,y);
    st.push({x,y,siz[y]});
    f[x]=y;
    siz[y]+=siz[x];
}

void solve(int p,int l,int r){
    int now=st.size();
    for(int i=0;i<tr[p].size();i++){
        pair<int,int> e=tr[p][i];
        merge(e.first,e.second);
    }
    if(l==r){
        if(uu[l]&&vv[l]){
            int x=find(uu[l]), y=find(vv[l]);
            if(x!=y) ans[l]=siz[x]*siz[y];
        }
    }
    else{
        int mid=l+r>>1;
        solve(lc,l,mid);
        solve(rc,mid+1,r);
    }
    while(st.size()>now){
        node t=st.top();
        st.pop();
        f[t.x]=t.x;
        siz[t.y]=t.hy;
    }
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n>>q;
    for(int i=1;i<=n;i++){
        f[i]=i;
        siz[i]=1;
    }
    int cnt=0;
    for(int i=1;i<=q;i++){
        cin>>op1[i]>>u1[i]>>v1[i];
        if(op1[i]=='Q') cnt++;
    }
    int time=0;
    for(int i=1;i<=q;i++){
        char op=op1[i];
        int u=u1[i], v=v1[i];
        if(u>v) swap(u,v);
        if(op=='A'){
            mp[{u,v}]=time+1;
        }else if(op=='Q'){
            time++;
            uu[time]=u;
            vv[time]=v;
            if(mp.count({u,v})){
                int stt=mp[{u,v}];
                if(stt<=time-1){
                    insert(1,1,cnt,stt,time-1,{u,v});
                }
                mp[{u,v}]=time+1;
            }
        }
    }
    time=cnt;
    for(auto &it:mp){
        pair<int,int> edge=it.first;
        int stt=it.second;
        if(stt<=time){
            insert(1,1,cnt,stt,time,edge);
        }
    }
    
    solve(1,1,cnt);
    
    for(int i=1;i<=cnt;i++) cout<<ans[i]<<endl;
    return 0;
}

P2147 [SDOI2008] 洞穴勘测

同理,模板套路题。
代码:

#include<bits/stdc++.h>
using namespace std;

#define lc (p<<1)
#define rc (p<<1|1)

const int N=2e5+5, M=1e4+5;
struct node{
    int x,y,hy;
};
stack<node> st;
vector<pair<int,int> > tr[N<<2];
int n,m;
int f[M],siz[M],ans[N];

struct Query{
    char op[10];
    int u,v;
}q[N];

map<pair<int,int>,int> mp;

void insert(int p,int l,int r,int x,int y, pair<int,int> e){
    if(x>r || y<l) return;
    if(x<=l && r<=y){
        tr[p].push_back(e);
        return;
    }
    int mid=l+r>>1;
    insert(lc,l,mid,x,y,e);
    insert(rc,mid+1,r,x,y,e);
}

int find(int x){
    while(x!=f[x]) x=f[x];
    return f[x];
}

void merge(int x,int y){
    x=find(x);
    y=find(y);
    if(x==y) return;
    if(siz[x]>siz[y]) swap(x,y);
    st.push({x,y,siz[y]});
    f[x]=y;
    siz[y]+=siz[x];
}

void solve(int p,int l,int r){
    int now=st.size();
    for(auto e:tr[p]){
        merge(e.first,e.second);
    }
    if(l==r){
        if(q[l].op[0]=='Q'){
            if(find(q[l].u)==find(q[l].v)){
                ans[l]=1;
            }
            else ans[l]=0;
        }
    }
    else{
        int mid=l+r>>1;
        solve(lc,l,mid);
        solve(rc,mid+1,r);
    }
    while(st.size()>now){
        node t=st.top();
        st.pop();
        f[t.x]=t.x;
        siz[t.y]=t.hy;
    }
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    
    cin>>n>>m;
    for(int i=1;i<=n;i++) f[i]=i,siz[i]=1;
    
    for(int i=1;i<=m;i++){
        cin>>q[i].op>>q[i].u>>q[i].v;
        if(q[i].u>q[i].v) swap(q[i].u,q[i].v);
        if(q[i].op[0]=='C'){
            mp[{q[i].u,q[i].v}]=i;
        }
        else if(q[i].op[0]=='D'){
            int start=mp[{q[i].u,q[i].v}];
            mp.erase({q[i].u,q[i].v});
            if(start<=i-1){
                insert(1,1,m,start,i-1,{q[i].u,q[i].v});
            }
        }
    }
    for(auto it:mp){
        int start=it.second;
        if(start<=m){
            insert(1,1,m,start,m,{it.first.first,it.first.second});
        }
    }
    solve(1,1,m);
    for(int i=1;i<=m;i++){
        if(q[i].op[0]=='Q'){
            if(ans[i]) cout<<"Yes\n";
            else cout<<"No\n";
        }
    }
    return 0;
}