求字符串中回文子串的个数(回文树详解)
写法一:
#include<stdio.h> #include<iostream> #include<algorithm> #include<string.h> #include<vector> #include<cmath> #include<string> #include<map> #include<queue> using namespace std; const int MAXN = 1005; struct node { int next[26];//指向在当前节点串的基础上左右两边同时添加相同字符后形成的回文串 int len;//此节点串的长度 int sufflink;//此节点的最长后缀回文串的指针 int num;//此节点中的回文子串的个数 }; int len; char s[MAXN]; node tree[MAXN]; int num; // node 1 - root with len -1, node 2 - root with len 0 int suff; // max suffix palindrome long long ans; bool addLetter(int pos) { int cur = suff;//当前以pos-1结尾的最长后缀回文子串的节点标号(即p的长度) int curlen; int let = s[pos] - a; //沿着后缀链接边找到满足xAx形式的最长后缀回文子串 while (true) { curlen = tree[cur].len;//当前A的长度 //判断xAx的两个x处的的字母是否都是x if (pos - 1 - curlen >= 0 && s[pos - 1 - curlen] == s[pos]) break; cur = tree[cur].sufflink;//沿着后缀链接边找 (节点下标) } //找到xAx后接着找此节点是否存在 //注意:节点表示A的next指针指向xAx,每个节点表示的是一个回文串 //xAx即为 以pos下标结尾的最长后缀回文子串的标号 if (tree[cur].next[let]) { suff = tree[cur].next[let];//保留当前以pos下标结尾的最长后缀回文子串的标号(已经有的) return false; } //每增加一个节点,num++ num++; suff = num;//保留当前以pos下标结尾的最长后缀回文子串的标号(新增加的) tree[num].len = tree[cur].len + 2;//增加了连个x tree[cur].next[let] = num;// 节点表示A的next指针指向xAx //判断A是都对应长度为-1的根 if (tree[num].len == 1) { tree[num].sufflink=2;//指向长度为0的根 tree[num].num = 1; return true; } //接下来找tree[num].sufflink while (true) { cur = tree[cur].sufflink; curlen = tree[cur].len; if (pos - 1 - curlen >= 0 && s[pos - 1 - curlen] == s[pos]) { tree[num].sufflink=tree[cur].next[let]; //tree[cur].next[let]--->xBx break; } } //求num节点中回文串的个数 //就是回文串xBx的个数加一(多了一个xAx) tree[num].num = 1 + tree[tree[num].sufflink].num; return true; } void initTree() { num = 2; suff = 2; tree[1].len = -1; tree[1].sufflink = 1;//长度为-1的根 tree[2].len = 0; tree[2].sufflink = 1;//长度为0的根 } int main() { scanf("%s", s); len = strlen(s); initTree(); for (int i = 0; i < len; i++) { addLetter(i); ans += tree[suff].num;//加上每个节点的回文串个数 } cout << ans << endl; return 0; }
写法二:
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #define ri register int using namespace std; const int maxn = 300010; int n, m, tot, cnt[maxn], len[maxn], fail[maxn], last, son[maxn][27], cur; char s[maxn]; long long ans; //添加节点 inline int new_node(ri x) { //更新len,cnt len[tot] = x; cnt[tot] = 0; return tot++; } //找到最长后缀节点 inline int get_fail(ri x, ri n) { while(s[n-len[x]-1] != s[n]) x = fail[x]; return x; } // inline void init() { scanf("%s", s+1); n = strlen(s+1); //初始化两个根节点 new_node(0); new_node(-1); fail[0] = 1; last = 0; } int main() { init(); for(ri i=1;i<=n;i++){ ri x = s[i] - a; cur = get_fail(last, i);//找到最长后缀回文串 if(!son[cur][x]){//没有此回文串 ri nw = new_node(len[cur]+2); fail[nw] = son[get_fail(fail[cur],i)][x]; son[cur][x] = nw; } last = son[cur][x]; cnt[last]++; } //计算出来的cnt[i]表示i节点回文子串出现的次数 for(int i=tot-1;i>=0;i--) cnt[fail[i]] += cnt[i]; for(int i=2;i<tot;i++) { ans+=cnt[i]; } cout<<ans<<endl; return 0; }