题目大意
一棵 /(n(1/le n/le10^5)/) 个节点的树,每个树上有一个颜色值 /(c_i(1/le c_i/le10)/) 。求树上本质不同的路径数,两条路径本质不同当且仅当路径上形成的颜色序列本质不同,保证度数为 /(1/) 的节点数量 /(<20/) 。
思路
如果只考虑从上到下的路径,那么把这棵树当成一颗 /(trie/) 建立广义 /(sam/) 即可解决,但是显然不只有从上到下的路径,于是我们有一个结论,即一颗无根树上任意一条路径必定可以在以某个叶节点为根时,变成一条从上到下的路径,于是我们枚举所有的度数为 /(1/) 的节点为根,将此时的树当做一棵新的 /(trie/) ,题目的限制保证了这样的 /(trie/) 不超过 /(20/) 棵,将所有这些 /(trie/) 进行合并后建立广义 /(sam/) 求解即可。
代码
#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
using LL = long long;
using LD = long double;
using ULL = unsigned long long;
using PII = pair<int, int>;
using TP = tuple<int, int, int>;
#define all(x) x.begin(),x.end()
#define mst(x,v) memset(x,v,sizeof(x))
#define mk make_pair
//#define int LL
//#define lc P*2
//#define rc P*2+1
#define endl '/n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL MOD = 1000000007;
const LL mod = 998244353;
const int maxn = 2000010;
int A[maxn], N, M, deg[maxn];
vector<int>G[maxn];
void add_edge(int from, int to)
{
G[from].push_back(to);
G[to].push_back(from);
}
int trie[maxn][26], c[maxn * 26], cnt = 1;
string SS[maxn];
int insert(int ch, int pos)
{
if (!trie[pos][ch])
{
trie[pos][ch] = ++cnt, c[cnt] = ch;
return cnt;
}
return trie[pos][ch];
}
int tot = 1, lst[maxn * 3], pos[maxn * 3];
struct Node {
int len, fa;
int ch[26];
bool isnp;
}sam[maxn * 3];
int extend(int last, char c)
{
int p = last, np = ++tot;
sam[np].len = sam[p].len + 1;
sam[np].isnp = true;
for (; p && !sam[p].ch[c]; p = sam[p].fa)
sam[p].ch[c] = np;
if (!p)
sam[np].fa = 1;
else
{
int q = sam[p].ch[c];
if (sam[q].len == sam[p].len + 1)
sam[np].fa = q;
else
{
int nq = ++tot;
sam[nq] = sam[q], sam[nq].len = sam[p].len + 1;
sam[nq].isnp = false;
sam[q].fa = sam[np].fa = nq;
for (; p && sam[p].ch[c] == q; p = sam[p].fa)
sam[p].ch[c] = nq;
}
}
return np;
}
void bfs()
{
queue<int>que;
pos[1] = 1;
for (int i = 0; i < 26; i++)
{
if (trie[1][i])
que.push(trie[1][i]), lst[trie[1][i]] = pos[1];
}
while (!que.empty())
{
int v = que.front();
que.pop();
pos[v] = extend(lst[v], c[v]);
for (int i = 0; i < 26; i++)
{
if (trie[v][i])
que.push(trie[v][i]), lst[trie[v][i]] = pos[v];
}
}
}
void dfs(int v, int p, int ppos)
{
int tmp = insert(A[v], ppos);
for (auto& to : G[v])
{
if (to == p)
continue;
dfs(to, v, tmp);
}
}
void solve()
{
LL ans = 0;
for (int i = 1; i <= N; i++)
{
if (deg[i] == 1)
dfs(i, 0, 1);
}
bfs();
for (int i = 2; i <= tot; i++)
ans += sam[i].len - sam[sam[i].fa].len;
cout << ans << endl;
}
signed main()
{
IOS;
cin >> N >> M;
for (int i = 1; i <= N; i++)
cin >> A[i];
int u, v;
for (int i = 1; i < N; i++)
cin >> u >> v, add_edge(u, v), deg[v]++, deg[u]++;
solve();
return 0;
}
原创文章,作者:506227337,如若转载,请注明出处:https://blog.ytso.com/tech/pnotes/273593.html