前言

由于本人实例还不能到cf青名,用不上jiangly哥哥的代码,暂时自己结合网上的代码和自己的理解写一个模板,日后再更新。

代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// O2 O3 优化
#pragma GCC optimize("O2")
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")

// 头文件
#include <bits/stdc++.h>

/*
// 如果是clang或者gcc 则使用以下头文件
#include <iostream>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <string>
#include <cmath>
#include <functional>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <climits>
*/

using namespace std;

// 常用宏定义
#define ll long long
#define ull unsigned long long
#define pii pair<int, int>
#define endl '\n'
#define pnt(x) cout<<#x<<'='<<(x)<<endl
#define pnt2(x, y) cout<<#x<<'='<<(x)<<','<<#y<<'='<<(y)<<endl

/*
// 如果想要快速输出stl的线性容器 可以使用以下代码
template <typename T1, typename T2>
ostream &operator<<(ostream &o, const pair<T1, T2> &p)
{
return o << "<" << p.first << ", " << p.second << ">";
}

template <typename T>
typename enable_if<
!is_same<T, string>::value &&
is_same<decltype(begin(declval<T>())), decltype(end(declval<T>()))>::value,
ostream&
>::type
operator<<(ostream &o, const T &v)
{
o << "{";
for (auto it = begin(v); it != end(v); ++it)
o << (it == begin(v) ? "" : " ,") << *it;
return o << "}";
}
*/

int main()
{
// 关闭同步流 加快IO
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);

return 0;
}

一些常量

类型 范围 常量名
int INT_MAX INT_MIN UINT_MAX
long long LLONG_MAX LLONG_MIN ULLONG_MAX
float $[-3.4 \times 10^{38}, 3.4 \times 10^{38}]$ FLT_MAX FLT_MIN DBL_MAX DBL_MIN
char $[0, 255]$ CHAR_MAX CHAR_MIN UCHAR_MAX
double $[-1.7 \times 10^{-308}, 1.7 \times 10^{308}]$ DBL_MAX DBL_MIN
long double $[-1.1 \times 10^{-4932}, 1.1 \times 10^{4932}]$ LDBL_MAX LDBL_MIN

数据结构模板

大数高精度模拟

处理大数 $10^{1000}$ 级别的加减乘除

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
struct BigInte
{
vector<int> d;
int sign;

BigInte(ll num=0){*this=num;}
BigInte(const string s){*this=s;}

BigInte& operator=(ll num)
{
d.clear(); sign=1;
if(num<0) sign=-1, num=-num;
if(num==0) d.push_back(0);
while(num)
{
d.push_back(num%10);
num/=10;
}
return *this;
}

BigInte& operator=(const string &s)
{
d.clear(); sign=1;
int start=0;
if(s[0]=='-') sign=-1, start=1;
for(int i=s.size()-1; i>=start; i--)
if(isdigit(s[i])) d.push_back(s[i]-'0');
trim(); // 去除前导0
return *this;
}

void trim()
{
while(!d.empty() && !d.back()) d.pop_back();
if(d.size()==1 && d[0]==0) sign=1;
}

string str() const
{
string s=(sign==-1? "-":"");
for(int i=(int)d.size()-1; i>=0; i--)s+=char('0'+d[i]);
return s;
}

bool absLess(const BigInte &b) const
{
if(d.size()!=b.d.size()) return d.size()<b.d.size();
for(size_t i=d.size()-1; i>=0; i--)
if(d[i]!=b.d[i]) return d[i]<b.d[i];
return false;
}

bool operator<(const BigInte &b) const
{
if(sign!=b.sign) return sign<b.sign;
if(sign==1)return absLess(b);
else return b.absLess(*this);
}

bool operator==(const BigInte &b) const
{
return sign==b.sign && d==b.d;
}

BigInte operator+(const BigInte &b) const
{
if(sign==b.sign)
{
BigInte c;
c.sign=sign;
c.d.resize(max(d.size(), b.d.size())+1,0);
int carry=0;
for(size_t i=0; i<c.d.size(); i++)
{
int x=carry;
if(i<d.size()) x+=d[i];
if(i<b.d.size()) x+=b.d[i];
c.d[i]=x%10;
carry=x/10;
}
c.trim();
return c;
}
return *this-(-b);
}

BigInte operator-() const
{
BigInte c=*this;
if(!(d.size()==1 && d[0]==0)) c.sign=-c.sign;
return c;
}

BigInte operator-(const BigInte &b) const
{
if(sign!=b.sign) return *this+(-b);
if((sign==1 && *this<b) || (sign==-1 && b<*this)) return -(b-*this);

BigInte c;
c.sign=sign;
c.d.resize(d.size(),0);
int borrow=0;
for(size_t i=0; i<d.size(); i++)
{
int x=d[i]-borrow;
if(i<b.d.size()) x-=b.d[i];
if(x<0) x+=10, borrow=1;
else borrow=0;
c.d[i]=x;
}
c.trim();
return c;
}

BigInte operator*(const BigInte &b) const
{
BigInte c;
c.sign=sign*b.sign;
c.d.assign(d.size()+b.d.size(),0);
for(size_t i=0; i<d.size(); i++)
{
int carry=0;
for(size_t j=0; j<b.d.size() || carry; j++)
{
long long cur=c.d[i+j]+(long long)d[i]*(j<b.d.size()? b.d[j]:0)+carry;
c.d[i+j]=cur%10;
carry=cur/10;
}
}
c.trim();
return c;
}

// 向下取整
BigInte operator/(const BigInte &b) const
{
BigInte a=*this, div=b;
a.sign=div.sign=1;
if(a.absLess(div)) return BigInte(0);

BigInte cur=0, res;
res.d.resize(d.size());
for(int i=(int)d.size()-1; i>=0; i--)
{
cur.d.insert(cur.d.begin(),d[i]);
cur.trim();
int x=0, l=0, r=9;
while(l<=r)
{
int m=(l+r)/2;
BigInte t=div*m;
if(!cur.absLess(t)) x=m, l=m+1;
else r=m-1;
}
res.d[i]=x;
cur=cur-div*x;
}
res.sign=sign*b.sign;
res.trim();
return res;
}
};
/*
//例子
BigInte a("1234987329857423985794783259");
BigInte b("124098321759817239843279812374");
cout << (a+b).str() << endl;
cout << (a*b).str() << endl;
*/

并查集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
struct DSU
{
int n;
vector<int> fa, sz;

DSU(int _n)
{
n=_n;
fa.resize(n);
sz.resize(n, 1);
for(int i=0; i<n; i++) fa[i]=i;
}

int find(int x)
{
if(fa[x]!=x) fa[x]=find(fa[x]);
return fa[x];
}

void unite(int x, int y)
{
int ra=find(x), rb=find(y);
if(ra!=rb)
{
if(sz[ra]<sz[rb]) swap(ra, rb);
fa[rb]=ra;
sz[ra]+=sz[rb];
}
}
};

字符串哈希

需要频繁子串比较 / 多次查询:用哈希更方便,尤其是在线性区间查询时几乎必选。

  • 双值hash
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
using ull = unsigned long long;
ull base = 131;
ull mod1 = 212370440130137957, mod2 = 1e9 + 7;

ull get_hash1(std::string s) {
int len = s.size();
ull ans = 0;
for (int i = 0; i < len; i++) ans = (ans * base + (ull)s[i]) % mod1;
return ans;
}

ull get_hash2(std::string s) {
int len = s.size();
ull ans = 0;
for (int i = 0; i < len; i++) ans = (ans * base + (ull)s[i]) % mod2;
return ans;
}

bool cmp(const std::string s, const std::string t) {
bool f1 = get_hash1(s) != get_hash1(t);
bool f2 = get_hash2(s) != get_hash2(t);
return f1 || f2;
}
  • 子串hash匹配
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
struct StringHash
{
using ull = unsigned long long;
static const ull base = 131;
vector<ull> h, p; // h 前缀哈希, p 幂次表

StringHash(const string &s)
{
int n = s.size();
h.assign(n + 1, 0);
p.assign(n + 1, 1);
for (int i = 1; i <= n; i++)
{
h[i] = h[i-1] * base + (s[i-1] - 'a' + 1);
p[i] = p[i-1] * base;
}
}

// 查询子串 [l,r] 的哈希 (1-indexed)
ull get(int l, int r)
{
return h[r] - h[l-1] * p[r-l+1];
}
};

使用例子,判断字字串符串是否相等 :

  • 这里先预处理字符串的hash前缀
  • 然后再计算子串s[l,l+1,,,r-1,r]的hash,并对比

二叉堆

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
struct BinaryHeap // 大根堆
{
vector<int> heap;

void push(int x)
{
heap.push_back(x);
int i=heap.size()-1;
while(i>1&&heap[i]>heap[i/2])
{
swap(heap[i],heap[i/2]);
i/=2;
}
}

void pop()
{
int n=heap.size();
heap[1]=heap[n-1];
heap.pop_back();
int i=1;
while(1)
{
int largest=i;
int l=i*2, r=i*2+1;
if(l<n && heap[l]>heap[largest]) largest=l;
if(r<n && heap[r]>heap[largest]) largest=r;
if(largest==i) break;
swap(heap[i],heap[largest]);
i=largest;
}
}

int top()
{
return heap[1];
}

void buildHeap(vector<int> &arr)
{
for(int i=arr.size()/2; i>=1; i--)
{
int j=i;
while(1)
{
int largest=j;
int l=j*2, r=j*2+1;
if(l<arr.size() && arr[l]>arr[largest]) largest=l;
if(r<arr.size() && arr[r]>arr[largest]) largest=r;
if(largest==j) break;
swap(arr[j],arr[largest]);
j=largest;
}
}
}
}
/*
用法:
BinaryHeap heap;
heap.push(1);
heap.push(2);
heap.push(3);
heap.pop();
cout << heap.top() << "\n";
vector<int> arr={3,212,33,44,15};
BinaryHeap heapp;
heapp.buildHeap(arr);
*/

最小生成树(MST)

  • Prim算法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <bits/stdc++.h>
using namespace std;
#define tup tuple<int,int,int>

struct node
{
int u,v,w;
};

vector<node> Prim(vector<vector<int>>& g, vector<vector<int>> &weight, int st)
{
vector<node> MST;
vector<bool> vis(g.size(), false);
priority_queue<tup, vector<tup>, greater<tup>> pq;
vis[st] = true;
for (auto &v: g[st])pq.push({weight[st][v], st, v});
while (!pq.empty() && (int)MST.size() < (int)g.size() - 1)
{
auto [w, u, v] = pq.top(); pq.pop();
if (vis[v]) continue;
vis[v] = true;
MST.push_back({u, v, w});
for (auto nxt: g[v])
{
if (!vis[nxt]) pq.push({weight[v][nxt], v, nxt});
}
}
return MST;
}

树状数组

支持单点更新add(x, val), 前缀查询sum(x), 区间查询rangeSum(l, r)

  • 一维树状数组
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
const int N=200005;
ll bit[N];
int n;

inline int lowbit(int x){return x&(-x);}

inline void add(int x, ll v)
{
for(; x<=n; x+=lowbit(x)) bit[x]+=v;
}

inline ll sum(int x)
{
ll s=0;
for(; x>0; x-=lowbit(x)) s+=bit[x];
return s;
}

// [l,r] 区间
inline ll rangeSum(int l, int r)
{
return sum(r)-sum(l-1);
}

void build(vector<int> &a)
{
for(int i=1; i<=n; i++) add(i,a[i]);
}
/*
用法:
vector<int> a={0,1,2,3,4,5};
n=a.size();
build(a);
add(pos, delta);
cout << rangeSum(l, r) << "\n";
*/
  • 二维树状数组
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
const int N=2005;
ll bit2[N][N];
int n, m;

inline int lowbit(int x){return x&(-x);}

inline void add(int x, int y, ll v)
{
for(int i=x; i<=n; i+=lowbit(i))
for(int j=y; j<=m; j+=lowbit(j))
bit2[i][j]+=v;
}

inline ll sum(int x, int y)
{
ll s=0;
for(int i=x; i>0; i-=lowbit(i))
for(int j=y; j>0; j-=lowbit(j))
s+=bit2[i][j];
return s;
}

// [(x1,y1),(x2,y2)] 矩形
inline ll rangeSum(int x1, int y1, int x2, int y2)
{
return sum(x2,y2)-sum(x1-1,y2)-sum(x2,y1-1)+sum(x1-1,y1-1);
}
  • 支持单点加/区间加,区间和, 前缀第k小的一维树状数组
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
const int N=200005;
ll bit1[N], bit2[N];
int n;

inline int lowbit(int x){return x&(-x);}

inline void add1(int x, ll v)
{
for(; x<=n; x+=lowbit(x)) bit1[x]+=v;
}

inline ll sum1(int x)
{
ll s=0;
for(; x>0; x-=lowbit(x)) s+=bit1[x];
return s;
}

// 未加的区间和
inline ll rangeSum1(int l, int r)
{
return sum1(r)-sum1(l-1);
}

// 区间加--单点查
inline void rangeAdd1(int l, int r, ll v)
{
add1(l,v);
add1(r+1,-v);
}

inline ll pointQuery(int x)
{
return sum1(x);
}

// 区间加--区间和

inline void add2(ll *bit, int x, ll v)
{
for(; x<=n; x+=lowbit(x)) bit[x]+=v;
}

inline void rangeAdd2(int l, int r, ll v)
{
add2(bit1,l,v);
add2(bit1,r+1,-v);
add2(bit2,l,v*(l-1));
add2(bit2,r+1,-v*r);
}

inline ll sum2(int x)
{
ll s1=0, s2=0, t=x;
for(; x>0; x-=lowbit(x))
{
s1+=bit1[x];
s2+=bit2[x];
}
return s1*t-s2;
}

inline ll rangeSum2(int l, int r)
{
return sum2(r)-sum2(l-1);
}

/*
使用这个kth时
应该声明使用bit来表示频率数组,如
memset(bit,0,sizeof(bit));
n=5;
add(1,2); // 2个1
add(2,3); // 3个2
*/
inline ll kth(int k)
{
int pos=0;
for(int i=1<<20; i; i>>=1)
{
if(pos+i<=n && bit1[pos+i]<k)
{
k-=bit1[pos+i];
pos+=i;
}
}
return pos+1;
}

算法模板

二分查找

注意数组必须有序!!!

1
2
3
4
5
6
7
8
9
10
11
12
bool check(ll x)
{
// 判断x是否满足条件
}

ll l=-1, r=LLONG_MAX;
while(l+1!=r)
{
ll m=(l+r)>>1;
if(check(m)) l=m;
else r=m;
}

数学

快速幂

  • 不带模的快速幂

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    // a^b
    ll qpow(ll a, ll b)
    {
    ll res=1;
    while(b)
    {
    if(b&1) res=res*a;
    a=a*a;
    b>>=1;
    }
    return res;
    }
  • 带模的快速幂

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    // a^b % mod
    ll qpowm(ll a, ll b, ll mod)
    {
    ll res=1%mod;
    a%=mod;
    while(b)
    {
    if(b&1) res=res*a%mod;
    a=a*a%mod;
    b>>=1;
    }
    return res;
    }

矩阵快速幂

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
const int mod=1e9+7;
vector<vector<ll>> matrix_mul(vector<vector<ll>> &a,vector<vector<ll>> &b)
{
int n=a.size(), p=b.size(), m=b[0].size();
vector<vector<ll>> c(n,vector<ll>(m,0));
for(int i=0; i<n; i++)
{
for(int k=0; k<p; k++)
{
if(a[i][k]==0) continue;
for(int j=0; j<m; j++)
c[i][j]=(c[i][j]+1LL*a[i][k]*b[k][j])%mod;
}
}
return c;
}

vector<vector<ll>> qpow(vector<vector<ll>> &a, ll p)
{
int n=a.size();
vector<vector<ll>> res(n,vector<ll>(n,0));
for(int i=0; i<n; i++) res[i][i]=1;
while(p)
{
if(p&1) res=matrix_mul(res,a);
a=matrix_mul(a,a);
p>>=1;
}
return res;
}

扩展欧几里得

1
2
3
4
5
6
7
8
9
10
11
pii exgcd(ll a, ll b, ll c)
{
if(b==0)
{
if(a==0)return {0,0};
if(c%a!=0)return {0,0};
return {c/a,0};
}
auto [x1,y1]=exgcd(b,a%b,c);
return {y1,x1-(a/b)*y1};
}

乘法逆元

  • 扩展欧几里得求逆元
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if(b == 0)
{
x = 1;
y = 0;
return a;
}
ll g = exgcd(b, a % b, y, x);
y -= a / b * x;
return g;
}

ll mod_inverse(ll a, ll m)
{
ll x, y;
ll g = exgcd(a, m, x, y);
if(g != 1) return -1;
return (x % m + m) % m;
}
  • 费马小定理求逆元(只适用于模数为质数)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
ll qpow(ll b, ll exp, ll mod)
{
ll res=1;
b%=mod;
while(exp)
{
if(exp&1)
res=(res*b)%mod;
b=(b*b)%mod;
exp>>=1;
}
return res;
}

ll mod_inverse(ll a, ll m)
{
return qpow(a, m - 2, m);
}
  • 线性求逆元(求1到n所有数的逆元)
1
2
3
4
5
6
7
8
vector<ll> line_inv(int n, ll mod)
{
vector<ll> inv(n+1);
inv[1]=1;
for(int i=2; i<=n; i++)
inv[i]=(mod-mod/i)*inv[mod%i]%mod;
return inv;
}

中国剩余定理

  • 模数两两互质
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if(b == 0)
{
x=1, y=0;
return a;
}
ll g=exgcd(b,a%b,y,x);
y-=a/b*x;
return g;
}

ll CRT(vector<ll> &a, vector<ll> &m)
{
ll M=1, res=0;
for(auto &mi: m) M*=mi;
for(int i=0; i<a.size(); i++)
{
ll Mi=M/m[i], x, y;
exgcd(Mi,m[i],x,y);
x=(x%M+M)%M;
res=(res+a[i]*Mi%M*x%M)%M;
}
return res;
}
  • 模数不两两互质
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if(b == 0)
{
x=1, y=0;
return a;
}
ll g=exgcd(b,a%b,y,x);
y-=a/b*x;
return g;
}

ll exCRT(vector<ll> &a, vector<ll> &m)
{
ll x=a[0], M=m[0];
for(int i=1; i<a.size(); i++)
{
ll a2=a[i], m2=m[i];
ll c=(a2-x%m2+m2)%m2, g=__gcd(M,m2);
if(c%g!=0) return -1;
// solve k*M=c(mod m2)
ll k, t;
exgcd(M,m2,k,t);
k=k*(c/g)%(m2/g);
k=(k%(m2/g)+(m2/g))%(m2/g);
x=x+k*M;
M=M/g*m2;
x%=M;
}
}

线筛质数、欧拉、莫比乌斯

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
const int N = 1e7+5;
const int MOD = 1000000007;
vector<ll> pri;
bool not_prime[N];
ll phi[N],mu[N];


void pre(ll n)
{
phi[1]=1;
mu[1]=1;
for(ll i=2; i<=n; ++i)
{
if(!not_prime[i])
{
pri.push_back(i);
phi[i]=i-1;
mu[i]=-1;
}
for(ll pri_j : pri)
{
if (i*pri_j>n) break;
not_prime[i*pri_j]=true;
if (i%pri_j == 0)
{
phi[i*pri_j]=phi[i]*pri_j;
mu[i*pri_j]=0;
break;
}
phi[i*pri_j]=phi[i]*phi[pri_j];
mu[i*pri_j]=-mu[i];
}
}
}

快速组合数

  • 预处理阶乘和阶乘逆元(适用于多次查询)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
const ll MOD = 1e9+7;
const int MAXN = 1e6+5;
vector<ll> fac(MAXN), ifac(MAXN);

ll qpow(ll b, ll exp, ll mod)
{
ll res=1;
b%=mod;
while(exp)
{
if(exp&1)
res=(res*b)%mod;
b=(b*b)%mod;
exp>>=1;
}
return res;
}

void pre(int n)
{
fac[0]=1;
for(int i=1; i<=n; i++) fac[i]=(fac[i-1]*i)%MOD;
ifac[n]=qpow(fac[n], MOD-2, MOD);
for(int i=n-1; i>=0; i--) ifac[i]=(ifac[i+1]*(i+1))%MOD;
}

// C_n^m
ll C(int n, int m)
{
if(n<0 || n>m) return 0;
return fac[n]*ifac[m]%MOD*ifac[n-m]%MOD;
}
  • 卢卡斯定理(适用于大数组合数模小质数)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
ll qpow(ll b, ll exp, ll mod)
{
ll res=1;
b%=mod;
while(exp)
{
if(exp&1)
res=(res*b)%mod;
b=(b*b)%mod;
exp>>=1;
}
return res;
}

ll small(ll m, ll n, ll p)
{
if(n<0 || n>m) return 0;
if(n==0 || n==m) return 1;
ll a=1, b=1;
for(int i=1; i<=n; i++)
{
a=a*(m-i+1)%p;
b=(b*i)%p;
}
return a*qpow(b, p-2, p)%p;
}

ll Lucas(ll m, ll n, ll p)
{
if(n==0) return 1;
return small(m%p, n%p, p)*Lucas(m/p, n/p, p)%p;
}

凸包算法

解决二维平面上的点集凸包问题,即求出包含所有点的最小凸多边形

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
struct Point
{
double x, y;
Point(double x=0, double y=0):x(x), y(y) {}
Point operator-(const Point &p) const {return Point(x-p.x, y-p.y);}
bool operator<(const Point &p) const {return x<p.x || (x==p.x && y<p.y);}
};

double cross(const Point &a, const Point &b)
{
return a.x*b.y-a.y*b.x;
}

// 两点距离平方
double disSq(const Point &a, const Point &b)
{
double dx=a.x-b.x, dy=a.y-b.y;
return dx*dx+dy*dy;
}

Point pivot; // 凸包基点
bool cmp(const Point &a, const Point &b)
{
double c=cross(a-pivot, b-pivot);
if(c==0) return disSq(a, pivot)<disSq(b, pivot);
return c>0;
}

vector<Point> convexHull(vector<Point> &p)
{
int n=p.size(), k=0;
if(n<=3) return p;
int minn=0;
for(int i=1; i<n; i++) if(p[i]<p[minn]) minn=i;
swap(p[0], p[minn]);
pivot=p[0];
sort(p.begin()+1, p.end(), cmp);
stack<Point> s;
s.push(p[0]), s.push(p[1]), s.push(p[2]);
for(int i=3; i<n; i++)
{
Point top=s.top(); s.pop();
while(cross(top-s.top(), p[i]-s.top())<=0)
{
top=s.top(); s.pop();
}
s.push(top); s.push(p[i]);
}
vector<Point> res;
while(!s.empty())
{
res.push_back(s.top());
s.pop();
}
reverse(res.begin(), res.end());
return res;
}

快速傅里叶变换(FFT)

用于多项式乘法 $O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
typedef complex<double> cd;
const double PI = acos(-1.0);
void FFT(vector<cd>& a, bool invert)
{
int n=a.size();
for(int i=1, j=0; i<n; i++)
{
int bit=n>>1;
for(; j&bit; bit>>=1) j^=bit;
j^=bit;
if(i<j) swap(a[i], a[j]);
}
for(int len=2; len<=n; len<<=1)
{
double ang=2*PI/len*(invert?-1:1);
cd wlen(cos(ang), sin(ang));
for(int i=0; i<n; i+=len)
{
cd w(1);
for(int j=0; j<len/2; j++)
{
cd u=a[i+j], v=w*a[i+j+len/2];
a[i+j]=u+v;
a[i+j+len/2]=u-v;
w*=wlen;
}
}
}
if(invert)
{
for(cd &x: a) x/=n;
}
}

// 多项式乘法
vector<int> multi(const vector<int> &a, const vector<int> &b)
{
vector<cd> fa(a.begin(), a.end()), fb(b.begin(), b.end());
int n=1;
while(n<a.size()+b.size()) n<<=1;
fa.resize(n), fb.resize(n);
FFT(fa, false), FFT(fb, false);
for(int i=0; i<n; i++) fa[i]*=fb[i];
FFT(fa, true);
vector<int> res(n);
for(int i=0; i<n; i++) res[i]=round(fa[i].real());
// 如果是大数乘法,需要处理进位
// int carry=0;
// for(int i=0; i<n; i++)
// {
// res[i]+=carry;
// carry=res[i]/10;
// res[i]%=10;
// }
return res;
}

数论变换(NTT)

对于需要在模意义下计算的情况,则需要NTT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
const int MOD = 998244353;
const int ROOT = 3;

ll qpow(ll b, ll exp, ll mod)
{
ll res=1;
b%=mod;
while(exp)
{
if(exp&1)
res=(res*b)%mod;
b=(b*b)%mod;
exp>>=1;
}
return res;
}

void NTT(vector<ll>& a, bool invert)
{
int n=a.size();
for(int i=1, j=0; i<n; i++)
{
int bit=n>>1;
for(; j&bit; bit>>=1) j^=bit;
j^=bit;
if(i<j) swap(a[i], a[j]);
}
for(int len=2; len<=n; len<<=1)
{
int wlen=qpow(ROOT, (MOD-1)/len, MOD);
if(invert) wlen=qpow(wlen, MOD-2, MOD);
for(int i=0; i<n; i+=len)
{
ll w=1;
for(int j=0; j<len/2; j++)
{
int u=a[i+j], v=w*a[i+j+len/2]%MOD;
a[i+j]=(u+v)%MOD;
a[i+j+len/2]=(u-v+MOD)%MOD;
w=(1LL*w*wlen)%MOD;
}
}
}
if(invert)
{
int nInv=qpow(n, MOD-2, MOD);
for(int &x: a) x=(1LL*x*nInv)%MOD;
}
}

拓扑排序

用于处理有向无环图(DAG)的排序问题,如:比赛排名、任务调度等

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
vector<int> topoSort(vector<vector<int> > &G)//G为邻接表
{
vector<int> topo;
vector<int> in(G.size(), 0); //记录每个节点的入度
stack<int> s; // 也可以使用priority_queue
for(int u=0; u<G.size(); u++)for(auto &v: G[u])in[v]++;
for(int u=0; u<G.size(); u++)if(in[u]==0)s.push(u);
while(!s.empty())
{
int u=s.top(); s.pop();
topo.push_back(u);
for(auto &v: G[u])
{
in[v]--;
if(in[v]==0)s.push(v);
}
}
return topo;
}

Dijkstra算法

用于求解单源最短路距离问题(非负权图),时间复杂度 $O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#define ll long long 
#define pii pair<int, int> // (dist, vertex)
const ll INF=LLONG_MAX;

vector<ll> dijkstra(const vector<vector<pii> > &G, int s)
{
ll n=G.size();
vector<ll> dist(n, INF);
priority_queue<pii, vector<pii>, greater<pii> > pq;
dist[s]=0;
pq.push({0, s});
while(!pq.empty())
{
auto [u, d]=pq.top(); pq.pop();
if(d>dist[u]) continue;
for(auto [v, w]: G[u])if(dist[u]+w<dist[v])
{
dist[v]=dist[u]+w;
pq.push({v, dist[v]});
}
}
return dist;
}

Bellman-Ford算法

适用于可能包含负权边的有向图或无向图的最短路问题,时间复杂度 $O(nm)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#define ll long long
const ll INF=LLONG_MAX;

struct Edge
{
int u, v, w;
Edge(int _u, int _v, int _w): u(_u), v(_v), w(_w) {}
};

/*
* @param s: 源点
* @param end: 终点
* @param E: 边集
* @param n: 点数
* @param dist: 最短路距离
* @param path: 最短路径
* @return: 是否存在负环
*/
bool bellmanFord(ll s, ll end, const vector<Edge> &E, ll n , vector<ll> &dist, vector<ll> &path)
{
vector<ll> pre(n, -1);
dist.assign(n, INF);
dist[s]=0;
for(ll i=0; i<n-1; i++)
{
bool flag=false;
for(const auto &e: E)
{
if(dist[e.u]!=INF && dist[e.u]+e.w<dist[e.v])
{
dist[e.v]=dist[e.u]+e.w;
pre[e.v]=e.u;
flag=true;
}
}
if(!flag) break;
}
// 判断是否存在负环
for(const auto &e: E)
if(dist[e.u]!=INF && dist[e.u]+e.w<dist[e.v]) return true;
if(dist[end]==INF) return false; // 不存在最短路

// 求最短路径
path.clear();
for(ll i=end; i!=-1; i=pre[i]) path.push_back(i);
reverse(path.begin(), path.end());

return false;
}

SPFA算法

该算法是Bellman-Ford算法的优化版本,时间复杂度 $O(km)$,其中 $k$ 是常数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#define ll long long 
const ll INF=LLONG_MAX;

struct Edge
{
int to, w;
Edge(int _to, int _w): to(_to), w(_w) {}
};

bool spfa(ll s, ll end, ll n, const vector<vector<Edge> > &G, vector<ll> &dist, vector<ll> &path)
{
vector<ll> pre(n, -1);
vector<int> cnt(n, 0);
vector<bool> inq(n, false);
dist.assign(n, INF);
dist[s]=0;
queue<int> q;
q.push(s);
inq[s]=true, cnt[s]=1;
while(!q.empty())
{
int u=q.front(); q.pop();
inq[u]=false;
for(const auto [v, w]: G[u])
{
if(dist[u]!=INF && dist[u]+w<dist[v])
{
dist[v]=dist[u]+w;
pre[v]=u;
if(!inq[v])
{
q.push(v);
inq[v]=true, cnt[v]++;
if(cnt[v]>n) return true; // 存在负环
}
}
}
}
if(dist[end]==INF) return false; // 不存在最短路

// 求最短路径
path.clear();
for(ll i=end; i!=-1; i=pre[i]) path.push_back(i);
reverse(path.begin(), path.end());

return false;
}

Floyd-Warshall 算法

用于求解所有顶点对之间的最短路径,适用于有向图或无向图,可以处理负权边(但不能有负权回路)。 时间复杂度 $O(n^3)$ , 仅适用于 $n \le 500$ 的情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
const int INF=0x3f3f3f3f;
const int MAXN=505;

int n, dis[MAXN][MAXN];

void floyd()
{
for(int k=1; k<=n; k++)
{
for(int i=1; i<=n; i++)
{
for(int j=1; j<=n; j++) if(dis[i][k]!=INF && dis[k][j]!=INF)
dis[i][j]=min(dis[i][j], dis[i][k]+dis[k][j]);
}
}
}
// 初始化
void init()
{
for(int i=1; i<=n; i++)
{
for(int j=1; j<=n; j++)
{
if(i==j) dis[i][j]=0;
else dis[i][j]=INF;
}
}
}
// 检测负权回路
bool hasNegativeCycle()
{
for(int k=1; k<=n; k++)
if(dis[k][k]<0) return true;
return false;
}
  • 如果需要重建最短路径,可以额外维护一个next矩阵,记录从ij的最短路径上i的后继节点。

Kosaraju算法

用于求解有向图的强连通分量(SCC)问题,如2-SAT问题,时间复杂度 $O(V+E)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
const int MAXN=100005;

vector<int> G[MAXN], G_[MAXN], comp_nodes[MAXN];
vector<int> order, comp(MAXN);
bool vis[MAXN];
int scc;

void dfs1(int u)
{
vis[u]=true;
for(auto v: G[u]) if(!vis[v]) dfs1(v);
order.push_back(u);
}

void dfs2(int u, int c)
{
comp[u]=c;
comp_nodes[c].push_back(u);
vis[u]=true;
for(auto v: G_[u]) if(!vis[v]) dfs2(v, c);
}

void kosaraju(int n)
{
order.clear();
scc=0;
memset(vis, false, sizeof(vis));
memset(comp, -1, sizeof(comp));
for(int i=1; i<=n; i++) if(!vis[i]) dfs1(i);
memset(vis, false, sizeof(vis));
for(int i=(int)order.size()-1; i>=0; i--) if(!vis[order[i]])
dfs2(order[i], ++scc);
}

vector<int> DAG[MAXN];
void buildDAG(int n)
{
for(int u=1; u<=n; u++)
{
for(auto v: G[i])if(comp[u]!=comp[v])
DAG[comp[u]].push_back(comp[v]);
}
// 去重边(可选)
for(int i=1; i<=scc; i++)
{
sort(DAG[i].begin(), DAG[i].end());
DAG[i].erase(unique(DAG[i].begin(), DAG[i].end()), DAG[i].end());
}
}

Tarjan算法

不仅可用于求强连通分量,还可用于求割点、桥等图论问题。时间复杂度 $O(V+E)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
const int N=100005;
vector<int> G[N];
int dfn[N], low[N], comp[N], inSt[N];
stack<int> st;
int dfsti=0, scc=0;

/*
* @param u 当前节点
*/
void tarjan(int u)
{
dfsn[u]=low[u]=++dfsti;
st.push(u); inSt[u]=1;
for(int v: G[u])
{
if(!dfsn[v])
{
tarjan(v);
low[u]=min(low[u], low[v]);
}
else if(inSt[v]) low[u]=min(low[u], dfn[v]);
}
if(dfsn[u]==low[u])
{
scc++;
while(1)
{
int x=st.top(); st.pop();
inSt[x]=0, comp[x]=scc;
if(x==u) break;
}
}
}
/*
使用方法:
for(int i=1; i<=n; i++) if(!dfsn[i]) tarjan(i); // 遍历所有节点求解强连通分量
*/

Kuhn-Munkres算法

匈牙利算法用于求解二分图的最大匹配问题,时间复杂度 $O(nm)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
const int MAXN=505;
int n,m; // 左右节点数
vector<int> G[MAXN];
vector<int> match; // match[v] 为与右部 v 匹配的左部点
vector<bool> vis;

bool dfs(int u)
{
for(int v: G[u])if(!vis[v])
{
vis[v]=true;
if(match[v]==-1 || dfs(match[v]))
{
match[v]=u;
return true;
}
}
return false;
}

int kuhn(int n)
{
int res=0;
match.assign(m+1, -1);
for(int i=1; i<=n; i++)
{
vis.assign(m+1, false);
if(dfs(i)) res++;
}
return res;
}

Hopcroft-Karp算法

用于求解大数据集下的二分图最大匹配问题,时间复杂度 $O(m\sqrt{n})$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
const int MAXN=505;
vector<int> G[MAXN];
int matchL[MAXN], matchR[MAXN], dist[MAXN];
int n,m; // 左右节点数

bool bfs()
{
queue<int> q;
for(int u=1; u<=n; u++)
{
if(matchL[u]==-1)
{
dist[u]=0;
q.push(u);
}
else dist[u]=-1;
}
bool f=false;
while(!q.empty())
{
int u=q.front(); q.pop();
for(int v: G[u])
{
if(matchR[v]==-1) f=true;
else if(dist[matchR[v]]==-1)
{
dist[matchR[v]]=dist[u]+1;
q.push(matchR[v]);
}
}
}
return f;
}

bool dfs(int u)
{
for(int v: G[u])
{
if(matchR[v]==-1 || (dist[matchR[v]]==dist[u]+1 && dfs(matchR[v])))
{
matchL[u]=v, matchR[v]=u;
return true;
}
}
dist[u]=-1;
return false;
}

int hopcroft_karp()
{
fill(matchL, matchL+n+1, -1);
fill(matchR, matchR+m+1, -1);
int res=0;
while(bfs())
{
for(int u=1; u<=n; u++) if(matchL[u]==-1) res+=dfs(u);
}
return res;
}

树的直径

树的直径是指树中任意两点间的最长路径。这里使用树形DP求解 $O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
const int MAXN=100005;
vector<int> tree[MAXN];
int dia=0;
int dfs(int u, int fa)
{
int max1=0, max2=0;
for(int v: tree[u])
{
if(v==fa) continue;
int d=dfs(v, u)+1;
if(d>max1) max2=max1, max1=d;
else if(d>max2) max2=d;
}
dia=max(dia, max1+max2);
return max1;
}

int treeDia(int n)
{
dia=0;
dfs(1, 0);
return dia;
}

倍增LCA

通过倍增预处理节点的 $2^k$ 个祖先,使得可以在时间复杂度 $O(\log n)$ 找到任意两个节点的最近公共祖先。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
const int MAXN=100005;
const int LOG=17; // log2(n)的向上取整

vector<int> tree[MAXN];
int dep[MAXN], fa[MAXN][LOG];

// BFS初始化深度和直接父节点
void bfs(int root, int n)
{
queue<int> q;
q.push(root);
dep[root]=1, fa[root][0]=-1;
while(!q.empty())
{
int u=q.front(); q.pop();
for(int v: tree[u])
{
if(v==fa[u][0]) continue;
dep[v]=dep[u]+1;
fa[v][0]=u;
q.push(v);
}
}
}


// 预处理倍增表
void init(int n)
{
bfs(1, n); // 根节点为1
for(int j=1; j<LOG; j++)
{
for(int i=1; i<=n; i++)
{
if(fa[i][j-1]==-1) fa[i][j]=-1;
else fa[i][j]=fa[fa[i][j-1]][j-1];
}
}
}

// 将节点u向上移动k步
int liftup(int u, int k)
{
for(int j=0; j<LOG; j++)
{
if(k&(1<<j))
{
u=fa[u][j];
if(u==-1) break;
}
}
}

// 找到u和v的最近公共祖先
int lca(int u, int v)
{
if(dep[u]<dep[v]) swap(u, v);
u=liftup(u, dep[u]-dep[v]);
if(u==v) return u;
for(int j=LOG-1; j>=0; j--)
if(fa[u][j]!=fa[v][j])
u=fa[u][j], v=fa[v][j];
return fa[u][0];
}

// 求两点间的距离
int dis(int u, int v)
{
int w=lca(u, v);
return dep[u]+dep[v]-2*dep[w];
}

// 使用之前记得init()

Tarjan算法(离线)

Tarjan算法(离线)可以高效地解决LCA的问题 $O(n+\alpha(n))$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
struct DSU
{
int n;
vector<int> fa, sz;

DSU(int _n)
{
n=_n;
fa.resize(n);
sz.resize(n, 1);
for(int i=0; i<n; i++) fa[i]=i;
}

int find(int x)
{
if(fa[x]!=x) fa[x]=find(fa[x]);
return fa[x];
}

void unite(int x, int y)
{
int ra=find(x), rb=find(y);
if(ra!=rb)
{
if(sz[ra]<sz[rb]) swap(ra, rb);
fa[rb]=ra;
sz[ra]+=sz[rb];
}
}
};

vector<int> tree[MAXN];
vector<pii> query[MAXN]; // query[u]=(v,query_id)
vector<int> ancestor, ans;
vector<bool> vis;

void tarjan(int u, int fa, DSU& dsu)
{
ancestor[u]=u;
for(int v: tree[u])
{
if(v==fa) continue;
tarjan(v, u, dsu);
dsu.unite(u, v);
ancestor[dsu.find(u)]=u;
}
vis[u]=true;
for(auto &[v, idx]: query[u]) if(vis[v])
ans[idx]=ancestor[dsu.find(v)];
}
/*
使用前记得初始化
DSU dsu(n+1);
ancestor.resize(n+1);
vis.resize(n+1,false);
ans.resize(t); // t个查询
//然后在query中加入查询
for(int i=0; i<m; i++)
query[u].push_back({v, i});
query[v].push_back({u, i});
//然后tarjan(1,0,dsu);
*/

字符串

STL的容器本身提供了一些字符串的算法:

  • unordered_map的原理是哈希表,可以用来快速查找字符串。
  • stringsubstr((int/size_t)pos, len) 函数可以用来截取子串。
  • stringfind()函数可以用来查找子串。找到,返回子串在原串中的起始位置(下标size_t), 否则返回string::npos
  • stringreplace(pos, len, str)函数可以用来替换子串,从pos开始长度为len的子串被替换为str
  • reverse()函数可以用来反转字符串。

KMP

KMP用于快速匹配字符串,时间复杂度 $O(n+m)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
vector<int> KMP(const string& s, const string& p)
{
// next array
int n=s.size(), m=p.size();
vector<int> nxt(m);
for(int i=1, j=0; i<m; i++)
{
while(j>0 && p[i]!=p[j]) j=nxt[j-1];
if(p[i]==p[j]) j++;
nxt[i]=j;
}
// search
vector<int> res;
for(int i=0, j=0; i<n; i++)
{
while(j>0 && s[i]!=p[j]) j=nxt[j-1];
if(s[i]==p[j]) j++;
}
}
/*
How to ues:
string s="abcabcaaabcaabccbabca", p="abc";
auto pos=KMP(s, p);
for(auto i: pos) cout<<i<<' '; // output initial the match positions
*/

Z-Algorithm

Z-Algorithm用于快速找到字符串的所有子串与其自身匹配的位置,时间复杂度 $O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
vector<int> Z_Algo(const string& s)
{
int n=s.size();
vector<int> z(n, 0);
int l=0, r=0;
for(int i=1; i<n; i++)
{
if(i<=r) z[i]=min(r-i+1, z[i-l]);
while(i+z[i]<n && s[z[i]]==s[i+z[i]]) z[i]++;
if(i+z[i]-1>r) l=i, r=i+z[i]-1;
}
return z;
}

vector<int> find(const string& s, const string& p)
{
string t=p+'#'+s;
vector<int> z=Z_Algo(t);
vector<int> res;
int m=p.size();
for(int i=m+1; i<(int)t.size(); i++)
if(z[i]>=m) res.push_back(i-m-1);
return res;
}
/*
How to ues:
string s="abcabcaaabcaabccbabca", p="abc";
auto pos=find(s, p); // output initial the match positions
}

Trie树(字典树)

Trie树用于快速插入、查找、删除、前缀匹配单词,时间复杂度 $O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
struct Trie
{
Trie* ch[26];
bool isEnd;
int cnt;
Trie(): isEnd(false), cnt(0)
{
memset(ch, 0, sizeof(ch));
}
};

void insert(Trie* root, const string& s)
{
Trie* p=root;
for(char c: s)
{
int idx=c-'a';
if(!p->ch[idx]) p->ch[idx]=new Trie();
p=p->ch[idx];
p->cnt++;
}
p->isEnd=true;
}

bool find(Trie* root, const string& s)
{
Trie* p=root;
for(char c: s)
{
int u=c-'a';
if(!p->ch[u]) return false;
p=p->ch[u];
}
return p->isEnd;
}

int prefixCnt(Trie* root, const string& s)
{
Trie* p=root;
for(char c: s)
{
int u=c-'a';
if(!p->ch[u]) return 0;
p=p->ch[u];
}
return p->cnt;
}

// 先检查find(s),再erase(s)
bool erase(Trie* root, const string& s)
{
if(!find(root, s)) return false;

Trie* p=root;
vector<Trie*> path;
for(char c: s)
{
int u=c-'a';
p=p->ch[u];
path.push_back(p);
}
p->isEnd=false;
for(int i=(int)s.size()-1; i>=0; i--) path[i]->cnt--;
return true;
}
/*
How to ues:
Trie* root=new Trie();
insert(root, "abc");
insert(root, "ab");
insert(root, "abcd");
cout<<find(root, "abc")<<endl; // output 1
cout<<prefixCnt(root, "ab")<<endl; // output 3
erase(root, "abcd");
cout<<find(root, "abcd")<<endl; // output 0
*/