[Solution][AC自动机]BZOJ2754

【思路】

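把所有点名串建成AC自动机,然后分别用每只喵的姓和名在自动机上匹配。每走到一个节点,就不断沿着last指针向上计数(last指向失配链上最近的、恰好是某个点名串结尾的节点,即当前匹配串中作为点名串的最长真后缀)。为避免同一只喵被同一个点名串重复计数,用path数组记录每个节点最后一次被哪只喵访问过:遇到已被当前喵访问过的节点就停止向上。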

由于字符集太大,需要用map存储每个节点的转移;同一个节点可能对应多个完全相同的点名串,因此用vector记录以该节点结尾的所有点名串编号。

【代码】

#include<cstdio>
#include<cctype>
#include<cstring>
#include<queue>
#include<map>
#include<vector>
using namespace std;
// Buffered single-character reader over stdin.
// Refills a static buffer of 1 << 15 bytes with fread() whenever it runs dry.
// Returns the next byte as a value in [0, 255], or EOF once stdin is exhausted.
// NOTE(review): this overloads the C library's getc(FILE*); a distinct name
// would be safer, but the rest of this file calls getc().
inline int getc()
{
    const int L = 1 << 15;
    static char buf[L] , *S = buf , *T = buf;
    if (S == T)
    {
        T = (S = buf) + fread(buf , 1 , L , stdin);
        if (S == T)
            return EOF;
    }
    // Go through unsigned char so a 0xFF byte is never mistaken for EOF and
    // the result is always a valid argument for the <cctype> functions.
    return static_cast<unsigned char>(*S++);
}
// Reads the next non-negative decimal integer via the buffered getc().
// Every non-digit character before the number is skipped, and the single
// character that terminates the number is consumed.
// Returns 0 if end of input is reached before any digit is seen; the previous
// version looped forever in that situation.
inline int getint()
{
    // Keep the full int returned by getc() so EOF is not truncated into a char.
    int c = getc();
    // Range test instead of isdigit(): never undefined for negative inputs.
    while (c != EOF && !(c >= '0' && c <= '9'))
        c = getc();
    int value = 0;
    while (c >= '0' && c <= '9')
    {
        value = value * 10 + (c - '0');
        c = getc();
    }
    return value;
}
// ---- Aho-Corasick automaton built from the roll-call strings; node 0 is the root ----
// NOTE(review): array capacities assume the BZOJ2754 input limits; confirm before reuse.
std::map<int , int> ch[100010];   // ch[u][x]: child of node u reached by symbol x
std::queue<int> q;                // BFS queue used while computing fail[] / last[]
std::vector<int> exist[100010];   // exist[u]: ids of the roll-call strings ending at node u
int ind;                          // highest trie node id allocated so far (0 = root only)
int fail[100010];                 // fail[u]: failure link of node u
int last[100010];                 // last[u]: nearest node on u's failure chain with non-empty exist[], or 0
int num[100010];                  // num[u]: symbol on the edge leading into node u
int path[100010];                 // path[u]: id of the last cat whose walk credited node u
// ---- Names and answers (cats and roll calls are numbered from 1) ----
int l1[20010];                    // l1[i]: length of cat i's first name
int l2[20010];                    // l2[i]: length of cat i's second name
int save[200010];                 // all name symbols stored back to back, in input order
int now;                          // cursor into save[]: index of the last symbol written/consumed
int ask[50010];                   // ask[j]: number of cats matched by roll call j
int name[20010];                  // name[i]: number of roll calls matched by cat i
// Credits cat `cur` with every roll-call string ending at node `ins` or at any
// node reached from it by following last[] links.
// path[] is a per-cat visit mark. When a node is first marked for `cur`, the
// walk continues along its whole last[] chain until it reaches the root (0) or
// an already-marked node, so a node found marked already had its chain
// credited. Stopping there ensures each roll call is counted at most once per cat.
// Written as a loop; it follows exactly the nodes the former tail recursion did.
void work(int ins , int cur)
{
    for (int node = ins ; node != 0 && path[node] != cur ; node = last[node])
    {
        path[node] = cur;
        const vector<int>& ids = exist[node];
        name[cur] += ids.size();
        for (vector<int>::const_iterator id = ids.begin() ; id != ids.end() ; ++id)
            ++ask[*id];
    }
}
// Feeds the next `len` symbols of save[] (those just after the global cursor
// `now`, which is advanced past them) through the automaton as one independent
// text, and calls work() to credit cat `cur` at every state reached.
// Matching always restarts at the root, so a match never spans two names.
void go(int len ,  int cur)
{
    int state = 0;
    for (int i = 1 ; i <= len ; ++i)
    {
        const int symbol = save[++now];
        map<int , int>::iterator edge = ch[state].find(symbol);
        // Fall back along failure links until some state has an edge on `symbol`.
        while (state != 0 && edge == ch[state].end())
        {
            state = fail[state];
            edge = ch[state].find(symbol);
        }
        if (edge != ch[state].end())
        {
            state = edge->second;
            work(state , cur);
        }
        // Otherwise state is the root and no roll call starts with `symbol`.
    }
}
// Reads the cats' names and the roll-call strings, builds an Aho-Corasick
// automaton over the roll calls, runs every name through it, and prints:
//   * one line per roll call: how many cats it matched;
//   * one line listing, for each cat, how many roll calls matched it.
int main()
{
    const int n = getint();
    const int m = getint();

    // Store each cat's two names back to back in save[]; l1/l2 keep the lengths.
    for (int i = 1 ; i <= n ; ++i)
    {
        l1[i] = getint();
        for (int j = 1 ; j <= l1[i] ; ++j)
            save[++now] = getint();
        l2[i] = getint();
        for (int j = 1 ; j <= l2[i] ; ++j)
            save[++now] = getint();
    }

    // Insert every roll-call string into the trie (map-based: large alphabet).
    for (int i = 1 ; i <= m ; ++i)
    {
        const int len = getint();
        int node = 0;
        for (int j = 1 ; j <= len ; ++j)
        {
            const int x = getint();
            map<int , int>::iterator it = ch[node].find(x);
            if (it == ch[node].end())
            {
                ++ind;
                num[ind] = x;
                it = ch[node].insert(map<int , int>::value_type(x , ind)).first;
            }
            node = it->second;
        }
        // Identical roll calls share one terminal node, hence a vector of ids.
        exist[node].push_back(i);
    }

    // BFS from the root's children: a node's fail/last links are computed only
    // after those of all shallower nodes. Depth-1 nodes keep fail = last = 0.
    for (map<int , int>::iterator it = ch[0].begin() ; it != ch[0].end() ; ++it)
        q.push(it->second);
    while (!q.empty())
    {
        const int r = q.front();
        q.pop();
        for (map<int , int>::iterator it = ch[r].begin() ; it != ch[r].end() ; ++it)
        {
            const int u = it->second;
            int v = fail[r];
            map<int , int>::iterator e = ch[v].find(num[u]);
            while (v != 0 && e == ch[v].end())
            {
                v = fail[v];
                e = ch[v].find(num[u]);
            }
            fail[u] = (e == ch[v].end()) ? 0 : e->second;
            // Nearest node on the failure chain that terminates a roll call.
            last[u] = exist[fail[u]].empty() ? last[fail[u]] : fail[u];
            q.push(u);
        }
    }

    // Rewind the cursor into save[] and match both names of every cat.
    now = 0;
    for (int i = 1 ; i <= n ; ++i)
    {
        go(l1[i] , i);
        go(l2[i] , i);
    }

    for (int i = 1 ; i <= m ; ++i)
        printf("%d\n" , ask[i]);
    // Space-separated per-cat counts; safe even when n == 0.
    for (int i = 1 ; i <= n ; ++i)
        printf(i == 1 ? "%d" : " %d" , name[i]);
    printf("\n");
    return 0;
}

comments powered by Disqus