树分治

抽象

Posted by JU on August 7, 2018

LuoguP3806

#include <bits/stdc++.h>
#define pr(p) printf("%d\n",p)
const int oo=2139063143;
const int N=100010;
const int mod=1000000007;
using namespace std;
typedef long long LL;
inline void sc (int &x)
{
    x=0; static int p; p=1; static char c; c=getchar();
    while (!isdigit(c)) { if (c=='-') p=-1; c=getchar(); }
    while ( isdigit(c)) { x=(x<<1)+(x<<3)+(c-48); c=getchar(); }
    x*=p;
}
int n,m,k[N],ans[N];
struct EDGE { int v,w,nx; }lb[N<<1]; int top[N],tot=1;
void add (int u,int v,int w)
{
    lb[++tot].v=v,lb[tot].w=w,lb[tot].nx=top[u],top[u]=tot;
}
int q[N],head,tail;
int dep[N],siz[N],masiz[N]; bool vis[N];
void dfs (int u)
{
    q[++tail]=u; siz[u]=vis[u]=1,masiz[u]=0;
    for (int kb=top[u]; kb!=-1; kb=lb[kb].nx)
    {
        int v=lb[kb].v,w=lb[kb].w;
        if (vis[v]) continue;
        dep[v]=dep[u]+w;
        dfs (v);
        siz[u]+=siz[v];
        masiz[u]=max (masiz[u],siz[v]);
    }
    vis[u]=0;
}
bool ly (int A,int B) { return dep[A]< dep[B]; }
int calc (int l,int r,int he)
{
    sort (q+l,q+r+1,ly);
    int an=0;
    for (int i=l,j=r; i< j; i++)
    {
        while (i< j&&dep[q[i]]+dep[q[j]]> he) --j;
        if (i==j) break;
        an+=(bool)(dep[q[i]]+dep[q[j]]==he);
    }
    return an;
}
void work (int u)
{
    head=1,tail=0;
    dfs (u);
    int t=siz[u],mi=masiz[u];
    for (int i=1; i<=tail; i++)
    {
        int v=q[i];
        int mm=max (masiz[v],t-siz[v]);
        if (mm< mi) mi=mm,u=v;
    }

    vis[u]=1;
    head=1,tail=0;
    for (int kb=top[u]; kb!=-1; kb=lb[kb].nx)
    {
        int v=lb[kb].v,w=lb[kb].w;
        if (vis[v]) continue;
        dep[v]=w;
        dfs (v);
        for (int i=1; i<=m; i++)
            ans[i]-=calc (head,tail,k[i]);
        head=tail+1;
    }

    dep[u]=0; q[++tail]=u;
    for (int i=1; i<=m; i++)
        ans[i]+=calc (1,tail,k[i]);

    for (int kb=top[u]; kb!=-1; kb=lb[kb].nx)
    {
        int v=lb[kb].v;
        if (vis[v]) continue;
        work (v);
    }
}
int main()
{
    //freopen (".in","r",stdin);
    //freopen (".out","w",stdout);
    memset (top,-1,sizeof (top));
    sc(n),sc(m);
    for (int i=1; i< n; i++)
    {
        int u,v,w; sc(u),sc(v),sc(w);
        add (u,v,w),add (v,u,w);
    }
    for (int i=1; i<=m; i++)
        sc(k[i]);
    work (1);
    for (int i=1; i<=m; i++)
        puts(ans[i]?"AYE":"NAY");
    return 0;
}

POJ1741

#include <bits/stdc++.h>
#define pr(p) printf("%d\n",p)
const int oo=2139063143;
const int N=101000;
const int mod=1000000007;
using namespace std;
typedef long long LL;
inline void sc (int &x)
{
    x=0; static int p; p=1; static char c; c=getchar();
    while (!isdigit(c)) { if (c=='-') p=-1; c=getchar(); }
    while ( isdigit(c)) { x=(x<<1)+(x<<3)+(c-48); c=getchar(); }
    x*=p;
}
int n,k;
struct EDGE { int v,nx,w; }lb[N<<1]; int top[N],tot=1;
void add (int u,int v,int w)
{
    lb[++tot].v=v,lb[tot].w=w,lb[tot].nx=top[u],top[u]=tot;
}
int q[N],head,tail;
int dep[N],siz[N],masiz[N]; bool vis[N];
bool ly (int A,int B) { return dep[A]<dep[B]; }
void dfs (int u)
{
    q[++tail]=u; siz[u]=vis[u]=1,masiz[u]=0;
    for (int kb=top[u]; kb!=-1; kb=lb[kb].nx)
    {
        int v=lb[kb].v,w=lb[kb].w;
        if (vis[v]) continue;
        dep[v]=dep[u]+w;
        dfs (v);
        siz[u]+=siz[v];
        masiz[u]=max (masiz[u],siz[v]);
    }
    vis[u]=0;
}
int calc (int l,int r)
{
    sort (q+l,q+r+1,ly);//sort是左闭右开区间
    int ans=0;
    for (int i=l,j=r; i< j; i++)
    {
        while (i< j&&dep[q[i]]+dep[q[j]]> k) --j;
        ans+=j-i;
    }
    return ans;
}
int work (int u)
{
    head=1,tail=0;
    dfs (u);
    int t=siz[u],mi=masiz[u];
    for (int i=1; i<=tail; i++)
    {
        int v=q[i];
        int m=max (masiz[v],t-siz[v]);
        if (m< mi) mi=m,u=v;
    }

    vis[u]=1;
    int ans=0;
    head=1,tail=0;
    for (int kb=top[u]; kb!=-1; kb=lb[kb].nx)
    {
        int v=lb[kb].v,w=lb[kb].w;
        if (vis[v]) continue;
        dep[v]=w;
        dfs (v);
        ans-=calc (head,tail);
        head=tail+1;
    }

    dep[u]=0; q[++tail]=u;
    ans+=calc (1,tail);

    for (int kb=top[u]; kb!=-1; kb=lb[kb].nx)
    {
        int v=lb[kb].v;
        if (vis[v]) continue;
        ans+=work (v);
    }

    return ans;
}
int main()
{
    //freopen (".in","r",stdin);
    //freopen (".out","w",stdout);
    sc(n),sc(k);
    while (n> 0||k> 0)
    {
        memset (top,-1,sizeof (top)),tot=1;
        memset (lb,0,sizeof (lb));
        memset (vis,0,sizeof (vis));
        for (int i=1; i< n; i++)
        {
            int u,v,w; sc(u),sc(v),sc(w);
            add (u,v,w),add (v,u,w);
        }
        pr(work(1));
        sc(n),sc(k);
    }

    return 0;
}

树分治

引入

给出特定条件,问整棵树有多少满足条件的路径。
如:给出一个值K,问整棵树有多少长度在K以内的路径。

点分治

将树上的路径分为两类,通过树根的,和不通过树根的。
通过树根的,路径的两端一定在两棵子树中,反之两端则在同一棵子树中。 仅需解决第一种情况,第二种可以递归处理。

方法

暴力。
DFS出每一个点到根的距离,$O(n)$。
然后算出有多少对点的距离总和小于等于K
当然,结果有一些不合法,比如两个点都在同一棵子树中,这些需要另外计算减掉。
但是,若子树大小减小得太慢,会退化成$O(n^2)$。
因此,要找到一个合适的根(重心),使最大子树的大小最小(小于一半)。
每次大小减半,$O(n\log_{}{n})$。

找重心

解决以u为根的子树重心:

void dfs (int u)
{
    q[++tail]=u; siz[u]=vis[u]=1,masiz[u]=0;
    for (int kb=top[u]; kb!=-1; kb=lb[kb].nx)
    {
        int v=lb[kb].v,w=lb[kb].w;
        if (vis[v]) continue;
        dep[v]=dep[u]+w;
        dfs (v);
        siz[u]+=siz[v];
        masiz[u]=max (masiz[u],siz[v]);
    }
    vis[u]=0;
}//用vis隔开不同的分治层次,之后dfs时就不会越过这个点。

head=1,tail=0;
dfs (u);
int t=siz[u],mi=masiz[u];
for (int i=1; i<=tail; i++)
{
    int v=q[i];
    int m=max (masiz[v],t-siz[v]);
    if (m< mi) mi=m,u=v;
}

处理通过重心路径

如果一条路径穿过重心,那么两端在不同的子树中。
每访问一棵子树,就把它的信息记录下来。
当访问新的子树的时候,和已经访问的子树的节点之间两两配对。
这样可以保证每一对端点都在两棵不同的子树中,而且只会算一次。
栗:条件为路径长度≤K,当前点到重心距离为L,算出有多少点到重心距离小于等于K-L即可。
数据结构可解决(树状数组)。
或排序后从两端各设指针相向移动。

int calc (int l,int r)
{
    sort (q+l,q+r+1,ly);//sort是左闭右开区间
    int ans=0;
    for (int i=l,j=r; i< j; i++)
    {
        while (i< j&&dep[q[i]]+dep[q[j]]> k) --j;
        //注意有些i是不能等于j的,可能要break
        ans+=j-i;
    }
    return ans;
}

更改与清除

不能用memset,会变$O(n^2)$。 清除时,应逐个撤销修改。

边分治

边分治

点分治树与边分治树

点分治树与边分治树

%%%LZZ

CC 原创文章采用CC BY-NC-SA 4.0协议进行许可,转载请注明:“转载自:树分治