Trie&AC自动机应用

字符串处理

Posted by JU on May 15, 2018

luoguP3808

    #include <algorithm>
    #include <iostream>
    #include <cstring>
    #include <string>
    #include <cstdio>
    #include <cmath>
    #include <queue>
    #define sc(p) scanf("%d",&p)
    #define pr(p) printf("%d",p)
    const int oo=0x7FFFFFFF;
    const int o=0x3FFFFFFF;
    const int N=1001000;
    using namespace std;
    struct NODE
    {
        int cnt,nx[30],fail,fa;
    }t[N];
    int tot=0,root=0;
    void insert (int now,string sb)
    {
        for (int i=0; i<sb.length(); i++)
        {
            if (!t[now].nx[sb[i]-'a'])
            {
                t[now].nx[sb[i]-'a']=++tot;
                t[t[now].nx[sb[i]-'a']].fa=now;
            }
            now=t[now].nx[sb[i]-'a'];
        }
        t[now].cnt++;
    }
    int search (int now,string sb)
    {
        for (int i=0; i<sb.length(); i++)
        {
            now=t[now].nx[sb[i]-'a'];
            if (!now) return 0;
        }
        return t[now].cnt;
    }
    queue<int> q;
    void init (int now)
    {
        for (int i=0; i<26; i++)
        if (t[root].nx[i])
        {
            t[t[root].nx[i]].fail=root;
            q.push (t[root].nx[i]);
        }
        while (!q.empty())
        {
            int cur=q.front();
            q.pop();
            for (int i=0; i<26; i++)
            {
                int nx=t[cur].nx[i];
                if (nx)
                {
                    t[nx].fail=t[t[cur].fail].nx[i];
                    q.push (nx);
                }
                else
                t[cur].nx[i]=t[t[cur].fail].nx[i];
            }
        }
    }
    int query (string sb)
    {
        int l=sb.length ();
        int now=root,ans=0;
        for (int i=0; i<l; i++)
        {
            now=t[now].nx[sb[i]-'a'];
            for (int g=now; g&&t[g].cnt!=-1; g=t[g].fail)
            ans+=t[g].cnt,t[g].cnt=-1;
        }
        return ans;
    }
    string sb;
    char sssb[2010000];
    int main()
    {
        //freopen (".in","r",stdin);
        //freopen (".out","w",stdout);
        int n; sc(n);
        for (int i=1; i<=n; i++)
        { string sb; scanf ("%s",sssb); sb=sssb; insert (root,sb); }
        init (root);
        scanf ("%s",sssb); sb=sssb;
        int ans=query (sb);
        cout<<ans;
        return 0;
    }

SMOJ1768 阿狸的打字机

构Trie & AC机

构图较为简单,此处不再赘述。

    int now=root,sumcnt=0;
    for (int i=0; i<l; i++)
    {
            if (sb[i]>='a'&&sb[i]<='z')
            {
                    if (!t[now].nx[sb[i]-'a'])
                    {
                            t[now].nx[sb[i]-'a']=++tot;
                            t[t[now].nx[sb[i]-'a']].fa=now;
                            t[now].vis[sb[i]-'a']=tot;
                            //要记录两次 
                    }
                    now=t[now].nx[sb[i]-'a'];
            }
            if (sb[i]=='B') now=t[now].fa;
            if (sb[i]=='P')
            {
                    t[now].cnt++;
                    t[now].id=++sumcnt;
                    wei[sumcnt]=tot;
                    //第sumcnt个串的结尾编号位置 
            }
    }
    init (root);

询问

本人使用了一个vector&一个struct来保存询问。

    for (int i=1; i<=n; i++)
{
	sc(ask[i].x),sc(ask[i].y),ask[i].id=i;
	ask[i].w=query[ask[i].y].size();
	query[ask[i].y].push_back (ask[i].x);
}

主要做法

首先构造fail树,求出每个节点的dfn与其子树的dfn范围;再对原trie进行dfs,每访问到一个节点,就将它的dfn序插入序列(树状数组维护),离开时则删去;每次访问到一个串的末尾,就调出与其相关的询问,计算并储存答案。

正确性

可以证明,在fail树中,x子树的每个节点所表示的字符串都包括了x串。此时,只需统计x子树中包含多少个Triey路径上的点即为答案。

构fail树

通过init求完所有点的fail后,将每个点与fail连一条边,形成fail图。

    for (int i=1; i<=tot; i++)
add (t[i].fail,i);
faildfs (root);

DFS:fail树

通过一次对fail树的dfs,求出每个点的dfs序(dfn)与其子树的dfn范围(dfn[i]~high[i])。

    int dfn[N]={0},sumdfn=0,high[N]={0};
    void faildfs (int u)
    {
            dfn[u]=++sumdfn; high[u]=dfn[u];
            for (int kb=top[u]; kb; kb=lb[kb].nx)
            {
                    int v=lb[kb].v;
                    if (dfn[v]) continue;
                    faildfs (v);
                    high[u]=max (high[u],high[v]);
            }
    }

DFS:Trie(&树状数组)

按照做法进行即可。

    int bit[N];
    void update (int w,int ad)
    {
            for (int i=w; i<=tot+1; i+=lowbit(i))
            bit[i]+=ad;
    }
    int getsum (int w)
    {
        int ans=0;
        for(int i=w;i;i-=lowbit(i))
            ans+=bit[i];
        return ans;
    }
    vector <int> query[N];
    vector <int> yhf[N];
    int wei[N];
    void tiredfs (int now)
    {
            update (dfn[now],1);
            for (int i=0; i<query[t[now].id].size(); i++)
            {
                    int x=query[t[now].id][i];
                    int q1=getsum (high[wei[x]]);
                    int q2=getsum (dfn[wei[x]]-1);
                    yhf[t[now].id].push_back (q1-q2);
            }
            for (int i=0; i<26; i++)
            {
                    int nx=t[now].nx[i];
                    if (nx) tiredfs (nx);
            }
            update (dfn[now],-1);
    }

输出

    for (int i=1; i<=n; i++)
pr(yhf[ask[i].y][ask[i].w]);
CC 原创文章采用CC BY-NC-SA 4.0协议进行许可,转载请注明:“转载自:Trie&AC自动机应用