前言

由于本人实例还不能到cf青名,用不上jiangly哥哥的代码,暂时自己结合网上的代码和自己的理解写一个模板,日后再更新。

使用 pandoc My-acm-icpc-template.md -o My-acm-icpc-template.docx —toc —highlight-style=tango导出为word

头、编译、debug

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// O2 O3 优化
#pragma GCC optimize("O2")
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")

// 头文件
#include <bits/stdc++.h>

/*
// 如果是clang或者gcc 则使用以下头文件
#include <iostream>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <string>
#include <cmath>
#include <functional>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <climits>
*/

using namespace std;

// 常用宏定义
#define ll long long
#define ull unsigned long long
#define pii pair<int, int>
#define endl '\n'
#define pnt(x) cout<<#x<<'='<<(x)<<endl
#define pnt2(x, y) cout<<#x<<'='<<(x)<<','<<#y<<'='<<(y)<<endl

/*
// 快速输出stl的线性容器
template <typename T1, typename T2>
ostream &operator<<(ostream &o, const pair<T1, T2> &p)
{
return o << "<" << p.first << ", " << p.second << ">";
}

template <typename T>
typename enable_if<
!is_same<T, string>::value &&
is_same<decltype(begin(declval<T>())), decltype(end(declval<T>()))>::value,
ostream&
>::type
operator<<(ostream &o, const T &v)
{
o << "{";
for (auto it = begin(v); it != end(v); ++it)
o << (it == begin(v) ? "" : " ,") << *it;
return o << "}";
}
*/

int main()
{
// 关闭同步流 加快IO
ios::sync_with_stdio(false);
cin.tie(0);

return 0;
}

编译运行

  • 命令运行
1
g++ -std=c++17 -O2 -Wall -Wextra -o XXX XXX.cpp
1
./XXX < X.input > X.output

这里使用脚本

1
2
3
4
5
6
7
#!/usr/bin/env bash
#usage: ./run <filename>
#example: ./run a (will compile a.cpp, run with a.input, output to a.output, compare with a.ans)

g++ -std=c++17 -O2 -o "$1" "$1.cpp" || exit 1
./"$1" < "$1.input" > "$1.output"
diff -q "$1.output" "$1.ans" && echo "AC" || echo "WA"

一些常量

类型 范围 常量名
int $[-2^{31}, 2^{31}-1]$ 即 $[-2147483648, 2147483647]$ INT_MAX INT_MIN UINT_MAX
long long $[-2^{63}, 2^{63}-1]$ 即 $[-9223372036854775808, 9223372036854775807]$ LLONG_MAX LLONG_MIN ULLONG_MAX
float $[-3.4 \times 10^{38}, 3.4 \times 10^{38}]$ FLT_MAX FLT_MIN DBL_MAX DBL_MIN
char $[0, 255]$ CHAR_MAX CHAR_MIN UCHAR_MAX
double $[-1.7 \times 10^{-308}, 1.7 \times 10^{308}]$ DBL_MAX DBL_MIN
long double $[-1.1 \times 10^{-4932}, 1.1 \times 10^{4932}]$ LDBL_MAX LDBL_MIN

时间复杂度预估

这里时间限制为1s

复杂度 勉强能过的数据规模
$O(1)$ 任意(常数)
$O(\log n)$ $n\le 10^{18}$ 甚至更大
$O(\sqrt n)$ $n\le 10^{12}$
$O(n)$ $n\le 10^{6}$
$O(n\log n)$ $n\le 10^{5}$
$O(n^2)$ $n \le 3\times10^{4}$(上限取决于常数)
$O(n^3)$ $n \le 500$
$O(2^n)$ $n \le 20$
$O(n!)$ $n \le 10$

数据结构模板

大数高精度模拟

处理大数 $10^{1000}$ 级别的加减乘除

手写结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
struct BigInte
{
vector<int> d;
int sign;

BigInte(ll num=0){*this=num;}
BigInte(const string s){*this=s;}

BigInte& operator=(ll num)
{
d.clear(); sign=1;
if(num<0) sign=-1, num=-num;
if(num==0) d.push_back(0);
while(num)
{
d.push_back(num%10);
num/=10;
}
return *this;
}

BigInte& operator=(const string &s)
{
d.clear(); sign=1;
int start=0;
if(s[0]=='-') sign=-1, start=1;
for(int i=s.size()-1; i>=start; i--)
if(isdigit(s[i])) d.push_back(s[i]-'0');
trim(); // 去除前导0
return *this;
}

void trim()
{
while(!d.empty() && !d.back()) d.pop_back();
if(d.size()==1 && d[0]==0) sign=1;
}

string str() const
{
string s=(sign==-1? "-":"");
for(int i=(int)d.size()-1; i>=0; i--)s+=char('0'+d[i]);
return s;
}

bool absLess(const BigInte &b) const
{
if(d.size()!=b.d.size()) return d.size()<b.d.size();
for(int i=(int)d.size()-1; i>=0; i--)
if(d[i]!=b.d[i]) return d[i]<b.d[i];
return false;
}

bool operator<(const BigInte &b) const
{
if(sign!=b.sign) return sign<b.sign;
if(sign==1)return absLess(b);
else return b.absLess(*this);
}

bool operator==(const BigInte &b) const
{
return sign==b.sign && d==b.d;
}

BigInte operator+(const BigInte &b) const
{
if(sign==b.sign)
{
BigInte c;
c.sign=sign;
c.d.resize(max(d.size(), b.d.size())+1,0);
int carry=0;
for(size_t i=0; i<c.d.size(); i++)
{
int x=carry;
if(i<d.size()) x+=d[i];
if(i<b.d.size()) x+=b.d[i];
c.d[i]=x%10;
carry=x/10;
}
c.trim();
return c;
}
return *this-(-b);
}

BigInte operator-() const
{
BigInte c=*this;
if(!(d.size()==1 && d[0]==0)) c.sign=-c.sign;
return c;
}

BigInte operator-(const BigInte &b) const
{
if(sign!=b.sign) return *this+(-b);
if((sign==1 && *this<b) || (sign==-1 && b<*this)) return -(b-*this);

BigInte c;
c.sign=sign;
c.d.resize(d.size(),0);
int borrow=0;
for(size_t i=0; i<d.size(); i++)
{
int x=d[i]-borrow;
if(i<b.d.size()) x-=b.d[i];
if(x<0) x+=10, borrow=1;
else borrow=0;
c.d[i]=x;
}
c.trim();
return c;
}

BigInte operator*(const BigInte &b) const
{
BigInte c;
c.sign=sign*b.sign;
c.d.assign(d.size()+b.d.size(),0);
for(size_t i=0; i<d.size(); i++)
{
int carry=0;
for(size_t j=0; j<b.d.size() || carry; j++)
{
long long cur=c.d[i+j]+(long long)d[i]*(j<b.d.size()? b.d[j]:0)+carry;
c.d[i+j]=cur%10;
carry=cur/10;
}
}
c.trim();
return c;
}

// 向下取整
BigInte operator/(const BigInte &b) const
{
BigInte a=*this, div=b;
a.sign=div.sign=1;
if(a.absLess(div)) return BigInte(0);

BigInte cur=0, res;
res.d.resize(d.size());
for(int i=(int)d.size()-1; i>=0; i--)
{
cur.d.insert(cur.d.begin(),d[i]);
cur.trim();
int x=0, l=0, r=9;
while(l<=r)
{
int m=(l+r)/2;
BigInte t=div*m;
if(!cur.absLess(t)) x=m, l=m+1;
else r=m-1;
}
res.d[i]=x;
cur=cur-div*x;
}
res.sign=sign*b.sign;
res.trim();
return res;
}
};
/*
//例子
BigInte a("1234987329857423985794783259");
BigInte b("124098321759817239843279812374");
cout << (a+b).str() << endl;
cout << (a*b).str() << endl;
*/

语言自带

但是这里不推荐使用大数模拟来写,cpp的大数有__int128pythondouble也支持

  • __int128 的使用(范围:$[-2^{127}, 2^{127}-1]$)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
using i128 = __int128;

// IO
istream& operator>>(istream& is, i128& x)
{
string s; is>>s;
x=0;
int flag=1;
for(auto c: s)
{
if(c=='-') {flag=-1; continue;}
x=x*10+(c-'0');
}
x*=flag;
return is;
}

ostream& operator<<(ostream& os, i128 x)
{
if(x<0) {os<<'-'; x=-x;}
if(x>9) os<<(x/10);
os<<(int)(x%10);
return os;
}

// 加减乘除模
i128 add(i128 a, i128 b){return a+b;}
i128 sub(i128 a, i128 b){return a-b;}
i128 mul(i128 a, i128 b){return a*b;}
i128 div(i128 a, i128 b){return a/b;} // 向0取整
i128 mod(i128 a, i128 b){return a%b;}

// 向下取整除法
i128 floorDiv(i128 a, i128 b)
{
return a/b - (a%b!=0 && (a^b)<0);
}

// 开方(整数部分)
i128 sqrtI128(i128 x)
{
if(x<0) return -1;
if(x==0) return 0;
i128 l=1, r=1e18, ans=1;
while(l<=r)
{
i128 m=(l+r)>>1;
if(m<=x/m) ans=m, l=m+1;
else r=m-1;
}
return ans;
}

// 快速幂
i128 qpow(i128 a, i128 b)
{
i128 res=1;
while(b)
{
if(b&1) res=res*a;
a=a*a;
b>>=1;
}
return res;
}

// 带模快速幂
i128 qpowm(i128 a, i128 b, i128 mod)
{
i128 res=1%mod;
a%=mod;
while(b)
{
if(b&1) res=res*a%mod;
a=a*a%mod;
b>>=1;
}
return res;
}

// GCD
i128 gcd(i128 a, i128 b)
{
if(b==0) return a;
return gcd(b, a%b);
}

// LCM
i128 lcm(i128 a, i128 b)
{
return a/gcd(a,b)*b;
}

// log2(整数部分)
int log2I128(i128 x)
{
if(x<=0) return -1;
int cnt=0;
while(x>1)
{
x>>=1;
cnt++;
}
return cnt;
}

// log10(整数部分)
int log10I128(i128 x)
{
if(x<=0) return -1;
int cnt=0;
while(x>=10)
{
x/=10;
cnt++;
}
return cnt;
}

/*
使用例子:
i128 a, b;
cin>>a>>b;
cout<<add(a,b)<<endl;
cout<<sqrtI128(a)<<endl;
cout<<gcd(a,b)<<endl;
cout<<log2I128(a)<<endl;
*/

并查集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
struct DSU
{
int n,cnt;
vector<int> fa, sz;

DSU(int _n)
{
n=_n;
cnt=_n; // 统计集合个数
fa.resize(n);
sz.resize(n, 1);
for(int i=0; i<n; i++) fa[i]=i;
}

int find(int x)
{
if(fa[x]!=x) fa[x]=find(fa[x]);
return fa[x];
}

void unite(int x, int y)
{
int ra=find(x), rb=find(y);
if(ra!=rb)
{
if(sz[ra]<sz[rb]) swap(ra, rb);
fa[rb]=ra;
sz[ra]+=sz[rb];
cnt--;
}
}
};

字符串哈希

需要频繁子串比较 / 多次查询:用哈希更方便,尤其是在线性区间查询时几乎必选。

双值hash

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
using ull = unsigned long long;
ull base=131;
ull mod1=212370440130137957, mod2=1e9+7;

ull hash1(string s)
{
int len=s.size();
ull ans=0;
for (int i=0; i<len; i++)ans=(ans*base+(ull)s[i])%mod1;
return ans;
}

ull hash2(string s)
{
int len=s.size();
ull ans=0;
for (int i=0; i<len; i++) ans=(ans*base+(ull)s[i])%mod2;
return ans;
}

bool cmp(const string s, const string t) {
bool f1=(hash1(s)!=hash1(t));
bool f2=(hash2(s)!=hash2(t));
return f1||f2;
}

子串hash匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
struct StringHash
{
using ull = unsigned long long;
static const ull base = 131;
vector<ull> h, p; // h 前缀哈希, p 幂次表

StringHash(const string &s)
{
int n=s.size();
h.assign(n+1, 0);
p.assign(n+1, 1);
for(int i=1; i<=n; i++)
{
h[i]=h[i-1]*base+(s[i-1]-'a'+1);
p[i]=p[i-1]*base;
}
}

// 查询子串 [l,r] 的哈希 (1-indexed)
ull get(int l, int r)
{
return h[r]-h[l-1]*p[r-l+1];
}
};

使用例子,判断字字串符串是否相等 :

  1. 这里先预处理字符串的hash前缀
  1. 然后再计算子串s[l,l+1,,,r-1,r]的hash,并对比

二叉堆

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
struct BinaryHeap // 大根堆
{
vector<int> heap;

void push(int x)
{
heap.push_back(x);
int i=heap.size()-1;
while(i>1&&heap[i]>heap[i/2])
{
swap(heap[i],heap[i/2]);
i/=2;
}
}

void pop()
{
int n=heap.size();
heap[1]=heap[n-1];
heap.pop_back();
int i=1;
while(1)
{
int lgst=i;
int l=i*2, r=i*2+1;
if(l<n && heap[l]>heap[lgst]) lgst=l;
if(r<n && heap[r]>heap[lgst]) lgst=r;
if(lgst==i) break;
swap(heap[i],heap[lgst]);
i=lgst;
}
}

int top()
{
return heap[1];
}

void buildHeap(vector<int> &arr)
{
for(int i=arr.size()/2; i>=1; i--)
{
int j=i;
while(1)
{
int lgst=j;
int l=j*2, r=j*2+1;
if(l<arr.size() && arr[l]>arr[lgst]) lgst=l;
if(r<arr.size() && arr[r]>arr[lgst]) lgst=r;
if(lgst==j) break;
swap(arr[j],arr[lgst]);
j=lgst;
}
}
}
};
/*
用法:
BinaryHeap heap;
heap.push(1);
heap.push(2);
heap.push(3);
heap.pop();
cout << heap.top() << "\n";
vector<int> arr={3,212,33,44,15};
BinaryHeap heapp;
heapp.buildHeap(arr);
*/

树状数组

支持单点更新add(x, val), 前缀查询sum(x), 区间查询rangeSum(l, r)

  • 一维树状数组
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
const int N=200005;
ll bit[N];
int n;

inline int lowbit(int x){return x&(-x);}

inline void add(int x, ll v)
{
for(; x<=n; x+=lowbit(x)) bit[x]+=v;
}

inline ll sum(int x)
{
ll s=0;
for(; x>0; x-=lowbit(x)) s+=bit[x];
return s;
}

// [l,r] 区间
inline ll rangeSum(int l, int r)
{
return sum(r)-sum(l-1);
}

void build(vector<int> &a)
{
for(int i=1; i<=n; i++) add(i,a[i]);
}
/*
用法:
vector<int> a={0,1,2,3,4,5};
n=a.size();
build(a);
add(pos, delta);
cout << rangeSum(l, r) << "\n";
*/
  • 二维树状数组
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
const int N=2005;
ll bit2[N][N];
int n, m;

inline int lowbit(int x){return x&(-x);}

inline void add(int x, int y, ll v)
{
for(int i=x; i<=n; i+=lowbit(i))
for(int j=y; j<=m; j+=lowbit(j))
bit2[i][j]+=v;
}

inline ll sum(int x, int y)
{
ll s=0;
for(int i=x; i>0; i-=lowbit(i))
for(int j=y; j>0; j-=lowbit(j))
s+=bit2[i][j];
return s;
}

// [(x1,y1),(x2,y2)] 矩形
inline ll rangeSum(int x1, int y1, int x2, int y2)
{
return sum(x2,y2)-sum(x1-1,y2)-sum(x2,y1-1)+sum(x1-1,y1-1);
}
  • 支持单点加/区间加,区间和, 前缀第k小的一维树状数组
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
const int N=200005;
ll bit1[N], bit2[N];
int n;

inline int lowbit(int x){return x&(-x);}

inline void add1(int x, ll v)
{
for(; x<=n; x+=lowbit(x)) bit1[x]+=v;
}

inline ll sum1(int x)
{
ll s=0;
for(; x>0; x-=lowbit(x)) s+=bit1[x];
return s;
}

// 未加的区间和
inline ll rangeSum1(int l, int r)
{
return sum1(r)-sum1(l-1);
}

// 区间加--单点查
inline void rangeAdd1(int l, int r, ll v)
{
add1(l,v);
add1(r+1,-v);
}

inline ll pointQuery(int x)
{
return sum1(x);
}

// 区间加--区间和

inline void add2(ll *bit, int x, ll v)
{
for(; x<=n; x+=lowbit(x)) bit[x]+=v;
}

inline void rangeAdd2(int l, int r, ll v)
{
add2(bit1,l,v);
add2(bit1,r+1,-v);
add2(bit2,l,v*(l-1));
add2(bit2,r+1,-v*r);
}

inline ll sum2(int x)
{
ll s1=0, s2=0, t=x;
for(; x>0; x-=lowbit(x))
{
s1+=bit1[x];
s2+=bit2[x];
}
return s1*t-s2;
}

inline ll rangeSum2(int l, int r)
{
return sum2(r)-sum2(l-1);
}

/*
使用这个kth时
应该声明使用bit来表示频率数组,如
memset(bit,0,sizeof(bit));
n=5;
add(1,2); // 2个1
add(2,3); // 3个2
*/
inline ll kth(int k)
{
int pos=0;
for(int i=1<<20; i; i>>=1)
{
if(pos+i<=n && bit1[pos+i]<k)
{
k-=bit1[pos+i];
pos+=i;
}
}
return pos+1;
}

ST表

支持 $O(n\log n)$ 预处理,$O(1)$ 查询区间最值(不支持修改)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
const int N=100005;
const int LOG=17;
int st[N][LOG], lg[N];
int n;

void build(int a[])
{
for(int i=1; i<=n; i++) st[i][0]=a[i];
for(int j=1; j<LOG; j++)
for(int i=1; i+(1<<j)-1<=n; i++)
st[i][j]=max(st[i][j-1], st[i+(1<<(j-1))][j-1]);

lg[1]=0;
for(int i=2; i<=n; i++) lg[i]=lg[i/2]+1;
}

int query(int l, int r)
{
int k=lg[r-l+1];
return max(st[l][k], st[r-(1<<k)+1][k]);
}

线段树

支持 $O(\log n)$ 单点/区间修改和查询

基础区间和+懒标记

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
const int N=100005;
ll tree[N<<2], lazy[N<<2];
int n;

void pushup(int rt)
{
tree[rt]=tree[rt<<1]+tree[rt<<1|1];
}

void pushdown(int rt, int ln, int rn)
{
if(lazy[rt])
{
lazy[rt<<1]+=lazy[rt];
lazy[rt<<1|1]+=lazy[rt];
tree[rt<<1]+=lazy[rt]*ln;
tree[rt<<1|1]+=lazy[rt]*rn;
lazy[rt]=0;
}
}

void build(int l, int r, int rt, int a[])
{
if(l==r)
{
tree[rt]=a[l];
return;
}
int m=(l+r)>>1;
build(l, m, rt<<1, a);
build(m+1, r, rt<<1|1, a);
pushup(rt);
}

void update(int L, int R, ll val, int l, int r, int rt)
{
if(L<=l && r<=R)
{
tree[rt]+=val*(r-l+1);
lazy[rt]+=val;
return;
}
int m=(l+r)>>1;
pushdown(rt, m-l+1, r-m);
if(L<=m) update(L, R, val, l, m, rt<<1);
if(R>m) update(L, R, val, m+1, r, rt<<1|1);
pushup(rt);
}

ll query(int L, int R, int l, int r, int rt)
{
if(L<=l && r<=R) return tree[rt];
int m=(l+r)>>1;
pushdown(rt, m-l+1, r-m);
ll res=0;
if(L<=m) res+=query(L, R, l, m, rt<<1);
if(R>m) res+=query(L, R, m+1, r, rt<<1|1);
return res;
}

/*
使用:
int a[N];
build(1, n, 1, a);
update(l, r, val, 1, n, 1); // 区间[l,r]加val
ll ans=query(l, r, 1, n, 1); // 查询区间[l,r]和
*/

动态开点线段树

适用于值域很大但修改/查询次数少的情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
const int N=2e7; // 节点数上限
struct Node
{
int ls, rs;
ll val, lazy;
}tr[N];
int root, idx;

void pushup(int rt)
{
tr[rt].val=tr[tr[rt].ls].val+tr[tr[rt].rs].val;
}

void pushdown(int rt, int ln, int rn)
{
if(!tr[rt].ls) tr[rt].ls=++idx;
if(!tr[rt].rs) tr[rt].rs=++idx;
if(tr[rt].lazy)
{
tr[tr[rt].ls].val+=tr[rt].lazy*ln;
tr[tr[rt].rs].val+=tr[rt].lazy*rn;
tr[tr[rt].ls].lazy+=tr[rt].lazy;
tr[tr[rt].rs].lazy+=tr[rt].lazy;
tr[rt].lazy=0;
}
}

void update(int L, int R, ll val, int l, int r, int &rt)
{
if(!rt) rt=++idx;
if(L<=l && r<=R)
{
tr[rt].val+=val*(r-l+1);
tr[rt].lazy+=val;
return;
}
int m=(l+r)>>1;
pushdown(rt, m-l+1, r-m);
if(L<=m) update(L, R, val, l, m, tr[rt].ls);
if(R>m) update(L, R, val, m+1, r, tr[rt].rs);
pushup(rt);
}

ll query(int L, int R, int l, int r, int rt)
{
if(!rt) return 0;
if(L<=l && r<=R) return tr[rt].val;
int m=(l+r)>>1;
pushdown(rt, m-l+1, r-m);
ll res=0;
if(L<=m) res+=query(L, R, l, m, tr[rt].ls);
if(R>m) res+=query(L, R, m+1, r, tr[rt].rs);
return res;
}

单调栈

维护一个单调递增或递减的栈,用于求每个元素左边/右边第一个比它大/小的元素

从左到右遍历,栈中维护单调性。若当前元素破坏单调性,则弹出栈顶直到满足条件。常见应用:最大矩形面积、接雨水等。

求左边第一个比它小的元素

1
2
3
4
5
6
7
8
9
10
11
12
13
vector<int> leftSmaller(vector<int> &a)
{
int n=a.size();
vector<int> ans(n);
stack<int> st; // 单调递增栈,存下标
for(int i=0; i<n; i++)
{
while(!st.empty() && a[st.top()]>=a[i]) st.pop();
ans[i]=st.empty()? -1: st.top();
st.push(i);
}
return ans;
}

最大矩形面积(柱状图)

给定 $n$ 个柱子的高度,求能构成的最大矩形面积。

思路:对每个柱子,找左右第一个比它矮的位置,宽度即为两位置之间距离。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
ll lgstRectangle(vector<int> &h)
{
int n=h.size();
vector<int> l(n), r(n);
stack<int> st;

// 左边第一个更小的
for(int i=0; i<n; i++)
{
while(!st.empty() && h[st.top()]>=h[i]) st.pop();
l[i]=st.empty()? 0: st.top()+1;
st.push(i);
}

while(!st.empty()) st.pop();

// 右边第一个更小的
for(int i=n-1; i>=0; i--)
{
while(!st.empty() && h[st.top()]>=h[i]) st.pop();
r[i]=st.empty()? n-1: st.top()-1;
st.push(i);
}

ll ans=0;
for(int i=0; i<n; i++)
ans=max(ans, (ll)h[i]*(r[i]-l[i]+1));
return ans;
}

单调队列

维护一个单调队列,用于求滑动窗口最值

队列保持单调性,队首为当前窗口的最值。每次移动窗口时,检查队首是否过期,队尾是否需要弹出。

滑动窗口最大值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
vector<int> maxSlidingWindow(vector<int> &a, int k)
{
int n=a.size();
vector<int> ans;
deque<int> q; // 存下标,单调递减

for(int i=0; i<n; i++)
{
// 移除过期元素
while(!q.empty() && q.front()<=i-k) q.pop_front();
// 维护单调性
while(!q.empty() && a[q.back()]<=a[i]) q.pop_back();
q.push_back(i);
if(i>=k-1) ans.push_back(a[q.front()]);
}
return ans;
}

单调队列优化DP

例:跳台阶,每次可跳 $[1,k]$ 步,第 $i$ 个台阶代价为 $c_i$,求最小总代价。

$dp[i]=\min(dp[j])+c[i],\ j\in[i-k,i-1]$,用单调队列维护区间最小值。

1
2
3
4
5
6
7
8
9
10
11
12
13
int n, k, c[N];
ll dp[N];
deque<int> q;

dp[0]=0;
q.push_back(0);
for(int i=1; i<=n; i++)
{
while(!q.empty() && q.front()<i-k) q.pop_front();
dp[i]=dp[q.front()]+c[i];
while(!q.empty() && dp[q.back()]>=dp[i]) q.pop_back();
q.push_back(i);
}

树链剖分

将树分解成若干条链,支持树上路径修改、查询,时间复杂度 $O(\log^2 n)$。

两次DFS,第一次求子树大小和重儿子,第二次划分重链。路径操作转化为链上的线段树操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
int n, a[N];
vector<int> tree[N];
int sz[N], dep[N], fa[N], son[N]; // 子树大小、深度、父节点、重儿子
int top[N], id[N], rk[N], cnt; // 链顶、DFS序、DFS序对应节点

void dfs1(int u, int f)
{
sz[u]=1; fa[u]=f; dep[u]=dep[f]+1;
for(int v: tree[u])
{
if(v==f) continue;
dfs1(v, u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}

void dfs2(int u, int t)
{
top[u]=t; id[u]=++cnt; rk[cnt]=u;
if(!son[u]) return;
dfs2(son[u], t); // 重儿子继承链顶
for(int v: tree[u])
if(v!=fa[u] && v!=son[u])
dfs2(v, v); // 轻儿子开启新链
}

// 线段树部分(维护和)
ll tree_seg[N<<2], lazy[N<<2];
void build(int l, int r, int rt)
{
if(l==r) {tree_seg[rt]=a[rk[l]]; return;}
int m=(l+r)>>1;
build(l, m, rt<<1); build(m+1, r, rt<<1|1);
tree_seg[rt]=tree_seg[rt<<1]+tree_seg[rt<<1|1];
}
void pushdown(int rt, int ln, int rn)
{
if(lazy[rt])
{
lazy[rt<<1]+=lazy[rt]; lazy[rt<<1|1]+=lazy[rt];
tree_seg[rt<<1]+=lazy[rt]*ln; tree_seg[rt<<1|1]+=lazy[rt]*rn;
lazy[rt]=0;
}
}
void update(int L, int R, ll val, int l, int r, int rt)
{
if(L<=l && r<=R) {tree_seg[rt]+=val*(r-l+1); lazy[rt]+=val; return;}
int m=(l+r)>>1;
pushdown(rt, m-l+1, r-m);
if(L<=m) update(L, R, val, l, m, rt<<1);
if(R>m) update(L, R, val, m+1, r, rt<<1|1);
tree_seg[rt]=tree_seg[rt<<1]+tree_seg[rt<<1|1];
}
ll query(int L, int R, int l, int r, int rt)
{
if(L<=l && r<=R) return tree_seg[rt];
int m=(l+r)>>1;
pushdown(rt, m-l+1, r-m);
ll res=0;
if(L<=m) res+=query(L, R, l, m, rt<<1);
if(R>m) res+=query(L, R, m+1, r, rt<<1|1);
return res;
}

// 路径修改
void pathUpdate(int u, int v, ll val)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u, v);
update(id[top[u]], id[u], val, 1, n, 1);
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u, v);
update(id[u], id[v], val, 1, n, 1);
}

// 路径查询
ll pathQuery(int u, int v)
{
ll res=0;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u, v);
res+=query(id[top[u]], id[u], 1, n, 1);
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u, v);
res+=query(id[u], id[v], 1, n, 1);
return res;
}

/*
使用:
dfs1(1, 0);
dfs2(1, 1);
build(1, n, 1);
pathUpdate(u, v, val);
ll ans=pathQuery(u, v);
*/

主席树(可持久化线段树)

支持查询历史版本的线段树,常用于求区间第k小

思路:每次修改不直接改原树,而是新建节点。利用函数式编程思想,不同版本共享未修改的子树。

区间第k小

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
const int N=2e5+5;
int n, m, a[N], b[N], cnt;
int root[N], ls[N*40], rs[N*40], sum[N*40];

int build(int l, int r)
{
int rt=++cnt;
if(l==r) return rt;
int m=(l+r)>>1;
ls[rt]=build(l, m);
rs[rt]=build(m+1, r);
return rt;
}

int update(int pre, int l, int r, int x)
{
int rt=++cnt;
ls[rt]=ls[pre]; rs[rt]=rs[pre]; sum[rt]=sum[pre]+1;
if(l==r) return rt;
int m=(l+r)>>1;
if(x<=m) ls[rt]=update(ls[pre], l, m, x);
else rs[rt]=update(rs[pre], m+1, r, x);
return rt;
}

int query(int u, int v, int l, int r, int k)
{
if(l==r) return l;
int m=(l+r)>>1;
int x=sum[ls[v]]-sum[ls[u]];
if(k<=x) return query(ls[u], ls[v], l, m, k);
else return query(rs[u], rs[v], m+1, r, k-x);
}

/*
使用:
// 离散化
for(int i=1; i<=n; i++) b[i]=a[i];
sort(b+1, b+n+1);
int len=unique(b+1, b+n+1)-b-1;
for(int i=1; i<=n; i++) a[i]=lower_bound(b+1, b+len+1, a[i])-b;

// 建树
root[0]=build(1, len);
for(int i=1; i<=n; i++)
root[i]=update(root[i-1], 1, len, a[i]);

// 查询[l,r]区间第k小
int ans=query(root[l-1], root[r], 1, len, k);
cout<<b[ans]<<endl; // 还原真实值
*/

分块算法

将序列分成 $\sqrt{n}$ 块,块内暴力,块间优化,时间复杂度 $O(n\sqrt{n})$。

思路:设块长 $B=\sqrt{n}$,预处理每块的信息。查询时,完整块直接用预处理结果,不完整块暴力。

区间加、区间查询和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
int n, B, a[N], belong[N]; // belong[i]表示i属于哪个块
ll sum[N], add[N]; // sum[i]为第i块的和,add[i]为第i块的加法标记

void init()
{
B=sqrt(n);
for(int i=1; i<=n; i++)
{
belong[i]=(i-1)/B+1;
sum[belong[i]]+=a[i];
}
}

void update(int l, int r, ll val)
{
if(belong[l]==belong[r]) // 同一块
{
for(int i=l; i<=r; i++) a[i]+=val, sum[belong[i]]+=val;
return;
}
// 左端不完整块
for(int i=l; belong[i]==belong[l]; i++) a[i]+=val, sum[belong[i]]+=val;
// 右端不完整块
for(int i=r; belong[i]==belong[r]; i--) a[i]+=val, sum[belong[i]]+=val;
// 中间完整块
for(int i=belong[l]+1; i<belong[r]; i++) add[i]+=val, sum[i]+=val*B;
}

ll query(int l, int r)
{
ll res=0;
if(belong[l]==belong[r])
{
for(int i=l; i<=r; i++) res+=a[i]+add[belong[i]];
return res;
}
for(int i=l; belong[i]==belong[l]; i++) res+=a[i]+add[belong[i]];
for(int i=r; belong[i]==belong[r]; i--) res+=a[i]+add[belong[i]];
for(int i=belong[l]+1; i<belong[r]; i++) res+=sum[i];
return res;
}

平衡树(Treap)

支持插入、删除、查询第k小、查询排名等操作,时间复杂度 $O(\log n)$。

结合BST和堆的性质,键值满足BST,优先级满足堆,通过旋转维护平衡。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
struct Treap
{
struct Node
{
int l, r, val, key, size;
}t[N];
int cnt, root;

int newNode(int val)
{
t[++cnt]={0, 0, val, rand(), 1};
return cnt;
}

void update(int p)
{
t[p].size=t[t[p].l].size+t[t[p].r].size+1;
}

int merge(int x, int y)
{
if(!x || !y) return x+y;
if(t[x].key<t[y].key)
{
t[x].r=merge(t[x].r, y);
update(x);
return x;
}
else
{
t[y].l=merge(x, t[y].l);
update(y);
return y;
}
}

void split(int p, int val, int &x, int &y)
{
if(!p) {x=y=0; return;}
if(t[p].val<=val)
{
x=p;
split(t[p].r, val, t[p].r, y);
}
else
{
y=p;
split(t[p].l, val, x, t[p].l);
}
update(p);
}

void insert(int val)
{
int x, y;
split(root, val, x, y);
root=merge(merge(x, newNode(val)), y);
}

void erase(int val)
{
int x, y, z;
split(root, val, x, z);
split(x, val-1, x, y);
y=merge(t[y].l, t[y].r);
root=merge(merge(x, y), z);
}

int getRank(int val) // 查询val的排名
{
int x, y;
split(root, val-1, x, y);
int res=t[x].size+1;
root=merge(x, y);
return res;
}

int kth(int p, int k) // 查询第k小
{
while(p)
{
if(t[t[p].l].size+1==k) return t[p].val;
if(t[t[p].l].size>=k) p=t[p].l;
else k-=t[t[p].l].size+1, p=t[p].r;
}
return 0;
}

int getPre(int val) // 查询前驱
{
int x, y;
split(root, val-1, x, y);
int p=x;
while(t[p].r) p=t[p].r;
int res=t[p].val;
root=merge(x, y);
return res;
}

int getNext(int val) // 查询后继
{
int x, y;
split(root, val, x, y);
int p=y;
while(t[p].l) p=t[p].l;
int res=t[p].val;
root=merge(x, y);
return res;
}
};

算法基础

二分查找

注意数组必须有序!!!

1
2
3
4
5
6
7
8
9
10
11
12
bool check(ll x)
{
// 判断x是否满足条件
}

ll l=-1, r=LLONG_MAX;
while(l+1!=r)
{
ll m=(l+r)>>1;
if(check(m)) l=m;
else r=m;
}

二分查找(浮点数)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
const double eps = 1e-9; // 精度
double l=0, r=1e9;
while(r-l>eps)
{
double m=(l+r)/2.0;
if(check(m)) l=m;
else r=m;
}

// 或者使用固定循环次数(推荐,更稳定)
double l = 0, r = 1e9;
for(int i=0; i<100; i++)
{
double m = (l+r)/2.0;
if(check(m)) l=m;
else r=m;
}

三分查找

用于在单峰函数(凸函数或凹函数)上寻找极值点,时间复杂度 $O(\log n)$ ,这里给出求最大值的代码,最小值改个条件即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// 求单峰函数的最大值点
double f(double x)
{
// 目标函数
}

const double eps = 1e-9;
double l=0, r=1e9;
while(r-l>eps)
{
double m1=l+(r-l)/3;
double m2=r-(r-l)/3;
if(f(m1)<f(m2)) l=m1;
else r=m2;
}
// 最大值点在 [l,r] 区间内,可以取 (l+r)/2

// 或者使用固定循环次数(推荐)
double l=0, r=1e9;
for(int i=0; i<100; i++)
{
double m1=l+(r-l)/3;
double m2=r-(r-l)/3;
if(f(m1)<f(m2)) l=m1;
else r=m2;
}

数论、几何、多项式

常用公式/结论

质因数分解相关

设 $n=\prod_{i=1}^{k}p_i^{q_i}$($p_i$ 为质因子,$q_i$ 为指数)

  • 因数个数:$d(n)=\prod_{i=1}^{k}(q_i+1)$
  • 因数和:$\sigma(n)=\prod{i=1}^{k}\frac{p_i^{q_i+1}-1}{p_i-1}=\prod{i=1}^{k}(1+p_i+p_i^2+\cdots+p_i^{q_i})$
  • 欧拉函数:$\varphi(n)=n\prod{i=1}^{k}(1-\frac{1}{p_i})=n\prod{i=1}^{k}\frac{p_i-1}{p_i}$
  • 莫比乌斯函数
    • 若存在 $q_i>1$,则 $\mu(n)=0$
    • 否则 $\mu(n)=(-1)^k$
  • 最大公约数:$\gcd(a,b)=\prod p_i^{\min(q_i^a, q_i^b)}$
  • 最小公倍数:$\text{lcm}(a,b)=\prod p_i^{\max(q_i^a, q_i^b)}$

基本定理

  • 裴蜀定理:对于任意整数 $a,b$,存在整数 $x,y$ 使得 $ax+by=\gcd(a,b)$
  • 费马小定理:若 $p$ 为质数,$\gcd(a,p)=1$,则 $a^{p-1}\equiv 1 \pmod{p}$
  • 欧拉定理:若 $\gcd(a,n)=1$,则 $a^{\varphi(n)}\equiv 1 \pmod{n}$
  • 威尔逊定理:$p$ 为质数当且仅当 $(p-1)! \equiv -1 \pmod{p}$
  • 中国剩余定理:若 $m_1,m_2,\ldots,m_k$ 两两互质,则同余方程组有唯一解

欧拉函数性质

  • $\varphi(1)=1$
  • $\varphi(p)=p-1$($p$ 为质数)
  • $\varphi(p^k)=p^{k-1}(p-1)$($p$ 为质数)
  • 若 $\gcd(m,n)=1$,则 $\varphi(mn)=\varphi(m)\varphi(n)$(积性函数)
  • $\sum_{d|n}\varphi(d)=n$

莫比乌斯函数性质

  • $\mu(1)=1$
  • 若 $\gcd(m,n)=1$,则 $\mu(mn)=\mu(m)\mu(n)$(积性函数)
  • $\sum_{d|n}\mu(d)=\begin{cases}1 & n=1 \ 0 & n>1\end{cases}$
  • 莫比乌斯反演
    • $f(n)=\sum{d|n}g(d) \Leftrightarrow g(n)=\sum{d|n}\mu(d)f(\frac{n}{d})$
    • $f(n)=\sum{n|d}g(d) \Leftrightarrow g(n)=\sum{n|d}\mu(\frac{d}{n})f(d)$

同余性质

  • $(a+b)\bmod m=(a\bmod m+b\bmod m)\bmod m$
  • $(a-b)\bmod m=(a\bmod m-b\bmod m)\bmod m$
  • $(a\times b)\bmod m=(a\bmod m\times b\bmod m)\bmod m$
  • $a\equiv b \pmod{m} \Rightarrow ac\equiv bc \pmod{m}$
  • 若 $\gcd(c,m)=1$,则 $ac\equiv bc \pmod{m} \Rightarrow a\equiv b \pmod{m}$

组合数性质

  • $C_n^m=C_n^{n-m}$
  • $Cn^m=C{n-1}^{m-1}+C_{n-1}^{m}$(杨辉三角)
  • $\sum_{i=0}^{n}C_n^i=2^n$
  • Lucas定理:$Cn^m \equiv C{n/p}^{m/p} \cdot C_{n\bmod p}^{m\bmod p} \pmod{p}$($p$ 为质数)

其他常用结论

  • 质数个数:不超过 $n$ 的质数约有 $\frac{n}{\ln n}$ 个
  • 1到n的质数个数:$\pi(n) \sim \frac{n}{\ln n}$

快速幂

  • 不带模的快速幂
1
2
3
4
5
6
7
8
9
10
11
12
// a^b
ll qpow(ll a, ll b)
{
ll res=1;
while(b)
{
if(b&1) res=res*a;
a=a*a;
b>>=1;
}
return res;
}
  • 带模的快速幂
1
2
3
4
5
6
7
8
9
10
11
12
13
// a^b % mod
ll qpowm(ll a, ll b, ll mod)
{
ll res=1%mod;
a%=mod;
while(b)
{
if(b&1) res=res*a%mod;
a=a*a%mod;
b>>=1;
}
return res;
}

矩阵快速幂

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
const int mod=1e9+7;
vector<vector<ll>> matrix_mul(vector<vector<ll>> &a,vector<vector<ll>> &b)
{
int n=a.size(), p=b.size(), m=b[0].size();
vector<vector<ll>> c(n,vector<ll>(m,0));
for(int i=0; i<n; i++)
{
for(int k=0; k<p; k++)
{
if(a[i][k]==0) continue;
for(int j=0; j<m; j++)
c[i][j]=(c[i][j]+1LL*a[i][k]*b[k][j])%mod;
}
}
return c;
}

vector<vector<ll>> qpow(vector<vector<ll>> &a, ll p)
{
int n=a.size();
vector<vector<ll>> res(n,vector<ll>(n,0));
for(int i=0; i<n; i++) res[i][i]=1;
while(p)
{
if(p&1) res=matrix_mul(res,a);
a=matrix_mul(a,a);
p>>=1;
}
return res;
}

扩展欧几里得

1
2
3
4
5
6
7
8
9
10
11
pii exgcd(ll a, ll b, ll c)
{
if(b==0)
{
if(a==0)return {0,0};
if(c%a!=0)return {0,0};
return {c/a,0};
}
auto [x1,y1]=exgcd(b,a%b,c);
return {y1,x1-(a/b)*y1};
}

乘法逆元

乘法逆元存在的条件:$\gcd(a, m) = 1$,即 $a$ 与 $m$ 互质

快速幂求逆元(费马小定理)

限制:要求 $m$ 为质数,且 $\gcd(a, m) = 1$

费马小定理:当 $p$ 为质数时,$a^{p-1} \equiv 1 \pmod{p}$,因此 $a^{-1} \equiv a^{p-2} \pmod{p}$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
ll qpow(ll b, ll exp, ll mod)
{
ll res=1;
b%=mod;
while(exp)
{
if(exp&1)
res=(res*b)%mod;
b=(b*b)%mod;
exp>>=1;
}
return res;
}

ll mod_inverse(ll a, ll m)
{
return qpow(a, m - 2, m);
}

扩展欧几里得求逆元

通用方法,适用于任意模数

原理:通过扩展欧几里得算法求解 $ax + my = 1$,得到 $ax \equiv 1 \pmod{m}$,即 $x$ 为 $a$ 在模 $m$ 意义下的逆元。当 $\gcd(a, m) \neq 1$ 时逆元不存在。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if(b==0)
{
x=1, y=0;
return a;
}
ll g=exgcd(b, a%b, y, x);
y-=a/b*x;
return g;
}

ll mod_inverse(ll a, ll m)
{
ll x, y;
ll g=exgcd(a, m, x, y);
if(g!=1) return -1; // 逆元不存在
return (x%m+m)%m;
}

线性求逆元

求1到n所有数的逆元,限制:要求 $\mathbf{mod}$ 为质数

时间复杂度 $O(n)$,适用于需要大量逆元的情况

1
2
3
4
5
6
7
8
vector<ll> line_inv(int n, ll mod)
{
vector<ll> inv(n+1);
inv[1]=1;
for(int i=2; i<=n; i++)
inv[i]=(mod-mod/i)*inv[mod%i]%mod;
return inv;
}

中国剩余定理

解为:

其中 $M=\prod_{i=1}^{n} m_i$,$M_i=M/m_i$,$y_i \equiv M_i^{-1} \pmod{m_i}$

模数两两互质

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if(b == 0)
{
x=1, y=0;
return a;
}
ll g=exgcd(b,a%b,y,x);
y-=a/b*x;
return g;
}

ll CRT(vector<ll> &a, vector<ll> &m)
{
ll M=1, res=0;
for(auto &mi: m) M*=mi;
for(int i=0; i<a.size(); i++)
{
ll Mi=M/m[i], x, y;
exgcd(Mi,m[i],x,y);
x=(x%M+M)%M;
res=(res+a[i]*Mi%M*x%M)%M;
}
return res;
}

模数不两两互质

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if(b == 0)
{
x=1, y=0;
return a;
}
ll g=exgcd(b,a%b,y,x);
y-=a/b*x;
return g;
}

ll exCRT(vector<ll> &a, vector<ll> &m)
{
ll x=a[0], M=m[0];
for(int i=1; i<a.size(); i++)
{
ll a2=a[i], m2=m[i];
ll c=(a2-x%m2+m2)%m2, g=__gcd(M,m2);
if(c%g!=0) return -1;
// solve k*M=c(mod m2)
ll k, t;
exgcd(M,m2,k,t);
k=k*(c/g)%(m2/g);
k=(k%(m2/g)+(m2/g))%(m2/g);
x=x+k*M;
M=M/g*m2;
x%=M;
}
}

质因数分解

基本方法

时间复杂度 $O(\sqrt{n})$

1
2
3
4
5
6
7
8
9
10
11
12
vector<int> p, q;
for(int i=2; i*i<=n; i++)if(n%i==0)
{
int j=0;
while(n%i==0)j++, n/=i;
p.push_back(i); q.push_back(j);
}
if(n>1)
{
p.push_back(n);
q.push_back(1);
}

快速方法

时间复杂度 $O(\frac{\sqrt{n}}{\ln n})$,需要预处理质数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
const int N = 1e6+5;
vector<int> primes;
bool not_prime[N];

void sieve(int n)
{
for(int i=2; i<=n; i++)
{
if(!not_prime[i]) primes.push_back(i);
for(int p: primes)
{
if(i*p>n) break;
not_prime[i*p]=true;
if(i%p==0) break;
}
}
}

// 质因数分解
vector<pii> factorize(ll n)
{
vector<pii> res; // {质因子, 指数}
for(int p: primes)
{
if(1LL*p*p>n) break;
if(n%p==0)
{
int cnt=0;
while(n%p==0) n/=p, cnt++;
res.push_back({p, cnt});
}
}
if(n>1) res.push_back({n, 1});
return res;
}

Pollard’s Rho 算法(大数分解)

时间复杂度 $O(n^{1/4})$,适用于 $n \le 10^{18}$ 的大数分解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
ll gcd(ll a, ll b)
{
return b? gcd(b, a%b): a;
}

// 快速乘(防止溢出)
ll mul(ll a, ll b, ll m)
{
return (__int128)a*b%m;
}

// 快速幂
ll qpow(ll a, ll b, ll m)
{
ll res=1;
while(b)
{
if(b&1) res=mul(res, a, m);
a=mul(a, a, m);
b>>=1;
}
return res;
}

// Miller-Rabin 素性测试
bool miller_rabin(ll n)
{
if(n<3 || n%2==0) return n==2;
ll u=n-1, t=0;
while(u%2==0) u/=2, t++;
ll test[]={2,3,5,7,11,13,17,19,23,29,31,37};
for(ll a: test)
{
if(n==a) return true;
ll v=qpow(a, u, n);
if(v==1 || v==n-1) continue;
for(int j=0; j<t; j++)
{
v=mul(v, v, n);
if(v==n-1) break;
}
if(v!=n-1) return false;
}
return true;
}

// Pollard's Rho 算法找因子
ll pollard_rho(ll n)
{
ll c=rand()%(n-1)+1;
ll x=rand()%n, y=x, d=1;
for(int i=1; d==1; i<<=1)
{
y=x;
for(int j=0; j<i; j++)
{
x=(mul(x, x, n)+c)%n;
d=gcd(abs(x-y), n);
if(d>1) break;
}
}
return d==n? pollard_rho(n): d;
}

// 递归分解
void factorize(ll n, map<ll,int> &factors)
{
if(n==1) return;
if(miller_rabin(n))
{
factors[n]++;
return;
}
ll d=pollard_rho(n);
factorize(d, factors);
factorize(n/d, factors);
}

/*
使用例子:
map<ll,int> factors;
factorize(n, factors);
for(auto [p, cnt]: factors)
cout<<p<<"^"<<cnt<<" ";
*/

线筛质数、欧拉、莫比乌斯

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
const int N = 1e6+5;
const int MOD = 1000000007;
vector<ll> pri;
bool not_prime[N];
ll phi[N],mu[N];


void pre(ll n)
{
phi[1]=1;
mu[1]=1;
for(ll i=2; i<=n; ++i)
{
if(!not_prime[i])
{
pri.push_back(i);
phi[i]=i-1;
mu[i]=-1;
}
for(ll p : pri)
{
if(i*p>n) break;
not_prime[i*p]=true;
if (i%p == 0)
{
phi[i*p]=phi[i]*p;
mu[i*p]=0;
break;
}
phi[i*p]=phi[i]*phi[p];
mu[i*p]=-mu[i];
}
}
}

欧拉函数

欧拉函数 $\varphi(n)$ 表示小于等于 $n$ 的正整数中与 $n$ 互质的数的个数

性质:$\varphi(n) = n \prod_{p|n} (1-\frac{1}{p})$,其中 $p$ 为 $n$ 的质因子

单个数

时间复杂度 $O(\sqrt{n})$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
ll phi(ll n)
{
ll res=n;
for(ll i=2; i*i<=n; i++)
{
if(n%i==0)
{
res=res/i*(i-1);
while(n%i==0) n/=i;
}
}
if(n>1) res=res/n*(n-1);
return res;
}

批量

时间复杂度 $O(n)$,见上面的”线筛质数、欧拉、莫比乌斯”部分

莫比乌斯函数

莫比乌斯函数 $\mu(n)$ 的定义:

  • $\mu(1) = 1$
  • $\mu(n) = 0$,若 $n$ 含有平方因子
  • $\mu(n) = (-1)^k$,若 $n$ 为 $k$ 个不同质因子的乘积

单个数

时间复杂度 $O(\sqrt{n})$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int mu(ll n)
{
int res=1, cnt=0;
for(ll i=2; i*i<=n; i++)
{
if(n%i==0)
{
int t=0;
while(n%i==0) n/=i, t++;
if(t>1) return 0; // 含有平方因子
cnt++; // 质因子个数
}
}
if(n>1) cnt++;
return (cnt&1)? -res: res;
}

批量

时间复杂度 $O(n)$,见上面的”线筛质数、欧拉、莫比乌斯”部分

快速组合数

预处理

预处理阶乘和阶乘逆元(适用于多次查询)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
const ll MOD = 1e9+7;
const int MAXN = 1e6+5;
vector<ll> fac(MAXN), ifac(MAXN);

ll qpow(ll b, ll exp, ll mod)
{
ll res=1;
b%=mod;
while(exp)
{
if(exp&1)
res=(res*b)%mod;
b=(b*b)%mod;
exp>>=1;
}
return res;
}

void pre(int n)
{
fac[0]=1;
for(int i=1; i<=n; i++) fac[i]=(fac[i-1]*i)%MOD;
ifac[n]=qpow(fac[n], MOD-2, MOD);
for(int i=n-1; i>=0; i--) ifac[i]=(ifac[i+1]*(i+1))%MOD;
}

// C_n^m
ll C(int n, int m)
{
if(m<0 || m>n) return 0;
return fac[n]*ifac[m]%MOD*ifac[n-m]%MOD;
}

卢卡斯定理

适用于大数组合数模小质数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
ll qpow(ll b, ll exp, ll mod)
{
ll res=1;
b%=mod;
while(exp)
{
if(exp&1)
res=(res*b)%mod;
b=(b*b)%mod;
exp>>=1;
}
return res;
}

ll small(ll m, ll n, ll p)
{
if(n<0 || n>m) return 0;
if(n==0 || n==m) return 1;
ll a=1, b=1;
for(int i=1; i<=n; i++)
{
a=a*(m-i+1)%p;
b=(b*i)%p;
}
return a*qpow(b, p-2, p)%p;
}

ll Lucas(ll m, ll n, ll p)
{
if(n==0) return 1;
return small(m%p, n%p, p)*Lucas(m/p, n/p, p)%p;
}

欧拉降幂

用于计算 $a^b \bmod m$,当 $b$ 非常大时(可能无法直接存储)

通用模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
ll phi(ll n)
{
ll res=n;
for(ll i=2; i*i<=n; i++)
{
if(n%i==0)
{
res=res/i*(i-1);
while(n%i==0) n/=i;
}
}
if(n>1) res=res/n*(n-1);
return res;
}

ll qpow(ll a, ll b, ll m)
{
ll res=1%m;
a%=m;
while(b)
{
if(b&1) res=res*a%m;
a=a*a%m;
b>>=1;
}
return res;
}

// 欧拉降幂
ll euler_pow(ll a, ll b, ll m)
{
if(m==1) return 0;
ll ph=phi(m);
if(b<ph) return qpow(a, b, m);
return qpow(a, b%ph+ph, m);
}

// 当b以字符串形式给出时
ll euler_pow_str(ll a, string b, ll m)
{
if(m==1) return 0;
ll ph=phi(m);

ll mod_b=0;
bool large=false;

for(char c: b)
{
mod_b=mod_b*10+(c-'0');
if(mod_b>=ph) large=true;
mod_b%=ph;
}

// 判断原始b是否>=ph(需要比较字符串)
if(!large)
{
ll realb=0;
for(char c: b)
{
realb=realb*10+(c-'0');
if(realb>=ph) {large=true; break;}
}
}

if(large) return qpow(a, mod_b+ph, m);

ll realb=0;
for(char c: b) realb=realb*10+(c-'0');
return qpow(a, realb, m);
}

/*
使用例子:
cout<<euler_pow(2, 1000000000, 1000000007)<<endl;

// 若b是字符串
string b="123456789012345678901234567890";
cout<<euler_pow_str(2, b, 1000000007)<<endl;
*/

幂塔问题(扩展)

计算 $a^{a^{a^{\cdots}}} \bmod m$(递归幂塔)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
ll power_tower(ll a, ll cnt, ll m)
{
if(m==1) return 0;
if(cnt==0) return 1%m;
if(a==1) return 1%m;

ll ph=phi(m);
ll exp=power_tower(a, cnt-1, ph);

// 判断 a^exp 是否 >= m
bool large=false;
ll tmp=1;
for(ll i=0; i<exp; i++)
{
tmp*=a;
if(tmp>=m) {large=true; break;}
}

if(large) return qpow(a, exp%ph+ph, m);
return qpow(a, exp, m);
}

/*
计算 2^2^2^2 mod 1000000007
cout<<power_tower(2, 4, 1000000007)<<endl;
*/

高斯消元

解 $n$ 元线性方程组 $Ax=b$,其中 $A$ 是 $n\times n$ 矩阵。

通过初等行变换将增广矩阵 $[A|b]$ 化为上三角形式,再回代求解。时间复杂度 $O(n^3)$。

浮点数版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
const double eps=1e-9;
int n;
double a[N][N]; // 增广矩阵 [A|b]

// 返回:0-唯一解,1-无解,2-无穷多解
int gauss()
{
int r=0; // 当前处理到第几行
for(int c=0; c<n; c++) // 枚举列
{
int t=r;
for(int i=r; i<n; i++) // 找主元
if(fabs(a[i][c])>fabs(a[t][c])) t=i;

if(fabs(a[t][c])<eps) continue; // 主元为0,跳过此列

for(int i=c; i<=n; i++) swap(a[t][i], a[r][i]); // 交换行
for(int i=n; i>=c; i--) a[r][i]/=a[r][c]; // 首项化为1

for(int i=r+1; i<n; i++) // 消元
if(fabs(a[i][c])>eps)
for(int j=n; j>=c; j--)
a[i][j]-=a[r][j]*a[i][c];
r++;
}

if(r<n)
{
for(int i=r; i<n; i++)
if(fabs(a[i][n])>eps) return 1;
return 2;
}

// 回代求解
for(int i=n-1; i>=0; i--)
for(int j=i+1; j<n; j++)
a[i][n]-=a[i][j]*a[j][n];

return 0; // 唯一解在 a[i][n]
}

整数版本(模意义)

求解 $Ax\equiv b \pmod{p}$,常用于求线性递推式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
ll a[N][N];
int n;
ll MOD;

ll qpow(ll a, ll b, ll p)
{
ll res=1; a%=p;
while(b) {if(b&1) res=res*a%p; a=a*a%p; b>>=1;}
return res;
}

int gauss()
{
int r=0;
for(int c=0; c<n; c++)
{
int t=r;
for(int i=r; i<n; i++)
if(a[i][c]) {t=i; break;}

if(!a[t][c]) continue;

for(int i=c; i<=n; i++) swap(a[t][i], a[r][i]);
ll inv=qpow(a[r][c], MOD-2, MOD);
for(int i=c; i<=n; i++) a[r][i]=a[r][i]*inv%MOD;

for(int i=r+1; i<n; i++)
if(a[i][c])
{
ll t=a[i][c];
for(int j=c; j<=n; j++)
a[i][j]=(a[i][j]-t*a[r][j]%MOD+MOD)%MOD;
}
r++;
}

if(r<n)
{
for(int i=r; i<n; i++)
if(a[i][n]) return 1;
return 2;
}

for(int i=n-1; i>=0; i--)
for(int j=i+1; j<n; j++)
a[i][n]=(a[i][n]-a[i][j]*a[j][n]%MOD+MOD)%MOD;

return 0;
}

求行列式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
// 浮点数版
double det(vector<vector<double>> &a)
{
int n=a.size();
double res=1;
for(int i=0; i<n; i++)
{
int k=i;
for(int j=i+1; j<n; j++)
if(fabs(a[j][i])>fabs(a[k][i])) k=j;
if(fabs(a[k][i])<eps) return 0;
if(k!=i) res=-res, swap(a[i], a[k]);
res*=a[i][i];
for(int j=i+1; j<n; j++) a[i][j]/=a[i][i];
for(int j=i+1; j<n; j++)
for(int k=i+1; k<n; k++)
a[j][k]-=a[j][i]*a[i][k];
}
return res;
}

// 整数版(模意义)
ll det_mod(ll a[][N], int n, ll p)
{
ll res=1;
for(int i=0; i<n; i++)
{
int k=i;
for(int j=i+1; j<n; j++)
if(a[j][i]) {k=j; break;}
if(!a[k][i]) return 0;
if(k!=i) res=p-res, swap(a[i], a[k]);
res=res*a[i][i]%p;
ll inv=qpow(a[i][i], p-2, p);
for(int j=i+1; j<n; j++)
{
ll t=a[j][i]*inv%p;
for(int k=i; k<n; k++)
a[j][k]=(a[j][k]-t*a[i][k]%p+p)%p;
}
}
return res;
}

线性基

用于处理异或线性组合问题,可以快速求解异或最大值、判断是否能异或出某个数等。

其中线性基是原数组的一个子集,满足:

  1. 原数组任意子集的异或和都能由线性基的某个子集异或得到
  2. 线性基中任意子集的异或和都不为0
  3. 线性基中元素个数唯一且最少

时间复杂度:插入 $O(\log V)$,查询 $O(\log V)$,其中 $V$ 为值域。

基本操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
struct LB
{
ll d[65]; // d[i]表示最高位为i的数
int cnt; // 线性基中元素个数
bool zero; // 是否能异或出0

LB() {memset(d, 0, sizeof(d)); cnt=0; zero=false;}

bool insert(ll x)
{
for(int i=62; i>=0; i--)
{
if(!(x&(1LL<<i))) continue;
if(!d[i]) {d[i]=x; cnt++; return true;}
x^=d[i];
}
zero=true; // 异或出0,说明线性相关
return false;
}

ll queryMax() // 查询最大异或和
{
ll res=0;
for(int i=62; i>=0; i--)
if((res^d[i])>res) res^=d[i];
return res;
}

ll queryMin() // 查询最小异或和(非0)
{
if(zero) return 0;
for(int i=0; i<=62; i++)
if(d[i]) return d[i];
return 0;
}

bool check(ll x) // 判断能否异或出x
{
for(int i=62; i>=0; i--)
{
if(!(x&(1LL<<i))) continue;
if(!d[i]) return false;
x^=d[i];
}
return true;
}
};

查询第k小异或和

先将线性基改为简化阶梯形(上三角),然后按二进制枚举。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
struct LB
{
ll d[65], p[65];
int cnt;
bool zero;

LB() {memset(d, 0, sizeof(d)); cnt=0; zero=false;}

bool insert(ll x)
{
for(int i=62; i>=0; i--)
{
if(!(x&(1LL<<i))) continue;
if(!d[i]) {d[i]=x; cnt++; return true;}
x^=d[i];
}
zero=true;
return false;
}

void rebuild() // 重构为简化阶梯形
{
for(int i=62; i>=0; i--)
for(int j=i-1; j>=0; j--)
if(d[i]&(1LL<<j)) d[i]^=d[j];

cnt=0;
for(int i=0; i<=62; i++)
if(d[i]) p[cnt++]=d[i];
}

ll queryKth(ll k) // 查询第k小(k从1开始)
{
if(zero) k--; // 如果能异或出0,0是最小的
if(k>=(1LL<<cnt)) return -1; // k太大

ll res=0;
for(int i=0; i<cnt; i++)
if(k&(1LL<<i)) res^=p[i];
return res;
}
};

合并两个线性基

1
2
3
4
5
6
7
LB merge(LB a, LB b)
{
LB c=a;
for(int i=62; i>=0; i--)
if(b.d[i]) c.insert(b.d[i]);
return c;
}

例子:区间异或最大值

给定数组 $a[1\ldots n]$,$q$ 次查询,每次询问 $[l,r]$ 中选若干个数异或的最大值。

思路:对每个右端点 $r$ 维护一个线性基,查询时直接用对应线性基。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
int n, q;
ll a[N];
LB lb[N];

void build()
{
for(int i=1; i<=n; i++)
{
lb[i]=lb[i-1];
lb[i].insert(a[i]);
}
}

ll query(int l, int r)
{
LB tmp;
for(int i=62; i>=0; i--)
{
if(!lb[r].d[i]) continue;
// 检查d[i]是否在[l,r]中
// 这里简化处理,直接返回最大值
tmp.d[i]=lb[r].d[i];
}
return tmp.queryMax();
}

平面几何

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
const double eps = 1e-9;
const double PI = acos(-1.0);

int sgn(double x)
{
if(fabs(x)<eps) return 0;
return x<0? -1: 1;
}

struct Point
{
double x, y;
Point(double x=0, double y=0):x(x), y(y) {}
Point operator+(const Point &p) const {return Point(x+p.x, y+p.y);}
Point operator-(const Point &p) const {return Point(x-p.x, y-p.y);}
Point operator*(double k) const {return Point(x*k, y*k);}
Point operator/(double k) const {return Point(x/k, y/k);}
bool operator<(const Point &p) const {return sgn(x-p.x)<0 || (sgn(x-p.x)==0 && sgn(y-p.y)<0);}
bool operator==(const Point &p) const {return sgn(x-p.x)==0 && sgn(y-p.y)==0;}
double len() {return sqrt(x*x+y*y);}
double len2() {return x*x+y*y;}
Point unit() {return *this/len();} // 单位向量
Point rotate(double ang) {return Point(x*cos(ang)-y*sin(ang), x*sin(ang)+y*cos(ang));} // 逆时针旋转
};

typedef Point Vector;

double dot(const Vector &a, const Vector &b) {return a.x*b.x+a.y*b.y;} // 点积
double cross(const Vector &a, const Vector &b) {return a.x*b.y-a.y*b.x;} // 叉积
double angle(const Vector &a, const Vector &b) {return acos(dot(a,b)/a.len()/b.len());} // 夹角

// 距离
double dis(const Point &a, const Point &b) {return (a-b).len();}
double dis2(const Point &a, const Point &b) {return (a-b).len2();}

struct Line
{
Point p, v; // 点+方向向量
Line() {}
Line(Point p, Vector v):p(p), v(v) {}
Point point(double t) {return p+v*t;} // 直线上的点
Line move(double d) {return Line(p+v.rotate(PI/2).unit()*d, v);} // 平移
};

// 点在直线上的投影
Point project(const Line &l, const Point &p)
{
return l.p+l.v*(dot(l.v, p-l.p)/l.v.len2());
}

// 点到直线距离
double dis_to_line(const Point &p, const Line &l)
{
return fabs(cross(l.v, p-l.p))/l.v.len();
}

// 点到线段距离
double dis_to_seg(const Point &p, const Point &a, const Point &b)
{
if(a==b) return dis(p, a);
Vector v1=b-a, v2=p-a, v3=p-b;
if(sgn(dot(v1,v2))<0) return v2.len();
if(sgn(dot(v1,v3))>0) return v3.len();
return fabs(cross(v1,v2))/v1.len();
}

// 两直线交点
Point line_intersection(const Line &a, const Line &b)
{
Vector u=a.p-b.p;
double t=cross(b.v, u)/cross(a.v, b.v);
return a.p+a.v*t;
}

// 多边形面积(有向面积)
double polygon_area(vector<Point> &p)
{
double area=0;
int n=p.size();
for(int i=0; i<n; i++)
area+=cross(p[i], p[(i+1)%n]);
return area/2;
}

// 点在多边形内(射线法)
int point_in_polygon(const Point &p, vector<Point> &poly)
{
int n=poly.size(), cnt=0;
for(int i=0; i<n; i++)
{
Point a=poly[i], b=poly[(i+1)%n];
if(sgn(cross(a-p, b-p))==0 && sgn(dot(a-p, b-p))<=0) return 0; // 在边上
int k=sgn(cross(b-a, p-a));
int u=sgn(a.y-p.y), v=sgn(b.y-p.y);
if(k>0 && u<0 && v>=0) cnt++;
if(k<0 && v<0 && u>=0) cnt--;
}
return cnt!=0; // 0:外部, 1:内部
}

// 圆
struct Circle
{
Point c;
double r;
Circle() {}
Circle(Point c, double r):c(c), r(r) {}
Point point(double a) {return Point(c.x+cos(a)*r, c.y+sin(a)*r);}
};

// 两圆交点
vector<Point> circle_intersection(Circle c1, Circle c2)
{
vector<Point> res;
double d=dis(c1.c, c2.c);
if(sgn(d)==0) return res; // 重合
if(sgn(d-c1.r-c2.r)>0) return res; // 相离
if(sgn(d-fabs(c1.r-c2.r))<0) return res; // 内含
double a=angle(c2.c-c1.c, Point(1,0));
double da=acos((c1.r*c1.r+d*d-c2.r*c2.r)/(2*c1.r*d));
Point p1=c1.point(a-da), p2=c1.point(a+da);
res.push_back(p1);
if(p1==p2) return res;
res.push_back(p2);
return res;
}

立体几何

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
struct Point3
{
double x, y, z;
Point3(double x=0, double y=0, double z=0):x(x), y(y), z(z) {}
Point3 operator+(const Point3 &p) const {return Point3(x+p.x, y+p.y, z+p.z);}
Point3 operator-(const Point3 &p) const {return Point3(x-p.x, y-p.y, z-p.z);}
Point3 operator*(double k) const {return Point3(x*k, y*k, z*k);}
Point3 operator/(double k) const {return Point3(x/k, y/k, z/k);}
double len() {return sqrt(x*x+y*y+z*z);}
double len2() {return x*x+y*y+z*z;}
};

typedef Point3 Vector3;

double dot(const Vector3 &a, const Vector3 &b) {return a.x*b.x+a.y*b.y+a.z*b.z;}

Vector3 cross(const Vector3 &a, const Vector3 &b)
{
return Vector3(a.y*b.z-a.z*b.y, a.z*b.x-a.x*b.z, a.x*b.y-a.y*b.x);
}

double dis(const Point3 &a, const Point3 &b) {return (a-b).len();}

// 平面(用法向量表示)
struct Plane
{
Point3 p, n; // 点+法向量
Plane() {}
Plane(Point3 p, Vector3 n):p(p), n(n) {}
};

// 点到平面距离
double dis2plane(const Point3 &p, const Plane &pl)
{
return fabs(dot(p-pl.p, pl.n))/pl.n.len();
}

// 直线与平面交点
Point3 lp_inter(Point3 p, Vector3 v, Plane pl)
{
double t=dot(pl.n, pl.p-p)/dot(pl.n, v);
return p+v*t;
}

// 四面体体积
double tet_vol(Point3 a, Point3 b, Point3 c, Point3 d)
{
return fabs(dot(cross(b-a, c-a), d-a))/6.0;
}

凸包算法

解决二维平面上的点集凸包问题,即求出包含所有点的最小凸多边形

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
struct Point
{
double x, y;
Point(double x=0, double y=0):x(x), y(y) {}
Point operator-(const Point &p) const {return Point(x-p.x, y-p.y);}
bool operator<(const Point &p) const {return x<p.x || (x==p.x && y<p.y);}
};

double cross(const Point &a, const Point &b)
{
return a.x*b.y-a.y*b.x;
}

// 两点距离平方
double disSq(const Point &a, const Point &b)
{
double dx=a.x-b.x, dy=a.y-b.y;
return dx*dx+dy*dy;
}

Point pivot; // 凸包基点
bool cmp(const Point &a, const Point &b)
{
double c=cross(a-pivot, b-pivot);
if(c==0) return disSq(a, pivot)<disSq(b, pivot);
return c>0;
}

vector<Point> convexHull(vector<Point> &p)
{
int n=p.size(), k=0;
if(n<=3) return p;
int minn=0;
for(int i=1; i<n; i++) if(p[i]<p[minn]) minn=i;
swap(p[0], p[minn]);
pivot=p[0];
sort(p.begin()+1, p.end(), cmp);
stack<Point> s;
s.push(p[0]), s.push(p[1]), s.push(p[2]);
for(int i=3; i<n; i++)
{
Point top=s.top(); s.pop();
while(cross(top-s.top(), p[i]-s.top())<=0)
{
top=s.top(); s.pop();
}
s.push(top); s.push(p[i]);
}
vector<Point> res;
while(!s.empty())
{
res.push_back(s.top());
s.pop();
}
reverse(res.begin(), res.end());
return res;
}

生成函数

生成函数是组合计数的强大工具,将数列 ${a_n}$ 编码为幂级数

普通生成函数(OGF)

常用公式

  • $\frac{1}{1-x}=1+x+x^2+\cdots=\sum_{n=0}^{\infty}x^n$
  • $\frac{1}{1-ax}=\sum_{n=0}^{\infty}a^nx^n$
  • $\frac{1}{(1-x)^2}=\sum_{n=0}^{\infty}(n+1)x^n$
  • $\frac{1}{(1-x)^k}=\sum{n=0}^{\infty}C{n+k-1}^{k-1}x^n$
  • $(1+x)^n=\sum_{k=0}^{n}C_n^kx^k$(二项式定理)
  • $\frac{1}{1-x-x^2}=\sum_{n=0}^{\infty}F_nx^n$(Fibonacci数)

运算

  • 加法:$(F+G)(x)=\sum_{n=0}^{\infty}(a_n+b_n)x^n$
  • 数乘:$(cF)(x)=\sum_{n=0}^{\infty}ca_nx^n$
  • 卷积:$(F\cdot G)(x)=\sum{n=0}^{\infty}\left(\sum{k=0}^{n}akb{n-k}\right)x^n$
  • 平移:$x^mF(x)=\sum_{n=0}^{\infty}a_nx^{n+m}$
  • 求导:$F’(x)=\sum_{n=1}^{\infty}na_nx^{n-1}$

指数生成函数(EGF)

常用公式

  • $e^x=\sum_{n=0}^{\infty}\frac{x^n}{n!}$
  • $e^{ax}=\sum_{n=0}^{\infty}\frac{a^n}{n!}x^n$
  • $\ln(1+x)=\sum_{n=1}^{\infty}\frac{(-1)^{n+1}}{n}x^n$
  • $\sin x=\sum_{n=0}^{\infty}\frac{(-1)^n}{(2n+1)!}x^{2n+1}$
  • $\cos x=\sum_{n=0}^{\infty}\frac{(-1)^n}{(2n)!}x^{2n}$

卷积:$(F\cdot G)(x)=\sum{n=0}^{\infty}\frac{1}{n!}\left(\sum{k=0}^{n}Cn^ka_kb{n-k}\right)x^n$

常见问题

1. 整数拆分:将 $n$ 拆分成若干正整数的方案数

2. 组合选取:从 $n$ 种物品中选取,每种可选 $0$ 到 $m_i$ 个

3. 错排问题:$n$ 个元素的错排数 $D_n$

生成函数:$\frac{e^{-x}}{1-x}$

4. 第二类斯特林数 $S(n,k)$:将 $n$ 个不同元素分成 $k$ 个非空集合

其中 $x^{\underline{k}}=x(x-1)(x-2)\cdots(x-k+1)$(下降阶乘幂)

代码(多项式运算)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
typedef vector<ll> poly;

// 多项式加法
poly add(poly a, poly b)
{
int n=max(a.size(), b.size());
poly c(n, 0);
for(int i=0; i<a.size(); i++) c[i]+=a[i];
for(int i=0; i<b.size(); i++) c[i]+=b[i];
return c;
}

// 多项式乘法O(n^2)
poly mul(poly a, poly b, ll mod)
{
int n=a.size(), m=b.size();
poly c(n+m-1, 0);
for(int i=0; i<n; i++)
for(int j=0; j<m; j++)
c[i+j]=(c[i+j]+a[i]*b[j])%mod;
return c;
}

// 快速幂(生成函数)
poly qpow(poly a, ll n, ll mod, int len)
{
poly res(len, 0);
res[0]=1;
while(n)
{
if(n&1) res=mul(res, a, mod);
a=mul(a, a, mod);
res.resize(len);
a.resize(len);
n>>=1;
}
return res;
}

// 计算 1/(1-x) 的前 n 项
poly inv_1_sub_x(int n, ll mod)
{
poly res(n);
for(int i=0; i<n; i++) res[i]=1;
return res;
}

// 二项式定理
poly binomial(ll n, int len, ll mod)
{
poly res(len);
res[0]=1;
for(int i=1; i<len; i++)
res[i]=res[i-1]*(n-i+1)%mod*qpow(i, mod-2, mod)%mod;
return res;
}

快速傅里叶变换(FFT)

用于多项式乘法 $O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
typedef complex<double> cd;
const double PI = acos(-1.0);
void FFT(vector<cd>& a, bool invert)
{
int n=a.size();
for(int i=1, j=0; i<n; i++)
{
int bit=n>>1;
for(; j&bit; bit>>=1) j^=bit;
j^=bit;
if(i<j) swap(a[i], a[j]);
}
for(int len=2; len<=n; len<<=1)
{
double ang=2*PI/len*(invert?-1:1);
cd wlen(cos(ang), sin(ang));
for(int i=0; i<n; i+=len)
{
cd w(1);
for(int j=0; j<len/2; j++)
{
cd u=a[i+j], v=w*a[i+j+len/2];
a[i+j]=u+v;
a[i+j+len/2]=u-v;
w*=wlen;
}
}
}
if(invert)
{
for(cd &x: a) x/=n;
}
}

// 多项式乘法
vector<int> multi(const vector<int> &a, const vector<int> &b)
{
vector<cd> fa(a.begin(), a.end()), fb(b.begin(), b.end());
int n=1;
while(n<a.size()+b.size()) n<<=1;
fa.resize(n), fb.resize(n);
FFT(fa, false), FFT(fb, false);
for(int i=0; i<n; i++) fa[i]*=fb[i];
FFT(fa, true);
vector<int> res(n);
for(int i=0; i<n; i++) res[i]=round(fa[i].real());
// 如果是大数乘法,需要处理进位
// int carry=0;
// for(int i=0; i<n; i++)
// {
// res[i]+=carry;
// carry=res[i]/10;
// res[i]%=10;
// }
return res;
}

数论变换(NTT)

对于需要在模意义下计算的情况,则需要NTT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
const int MOD = 998244353;
const int ROOT = 3;

ll qpow(ll b, ll exp, ll mod)
{
ll res=1;
b%=mod;
while(exp)
{
if(exp&1)
res=(res*b)%mod;
b=(b*b)%mod;
exp>>=1;
}
return res;
}

void NTT(vector<ll>& a, bool invert)
{
int n=a.size();
for(int i=1, j=0; i<n; i++)
{
int bit=n>>1;
for(; j&bit; bit>>=1) j^=bit;
j^=bit;
if(i<j) swap(a[i], a[j]);
}
for(int len=2; len<=n; len<<=1)
{
int wlen=qpow(ROOT, (MOD-1)/len, MOD);
if(invert) wlen=qpow(wlen, MOD-2, MOD);
for(int i=0; i<n; i+=len)
{
ll w=1;
for(int j=0; j<len/2; j++)
{
int u=a[i+j], v=w*a[i+j+len/2]%MOD;
a[i+j]=(u+v)%MOD;
a[i+j+len/2]=(u-v+MOD)%MOD;
w=(1LL*w*wlen)%MOD;
}
}
}
if(invert)
{
int nInv=qpow(n, MOD-2, MOD);
for(int &x: a) x=(1LL*x*nInv)%MOD;
}
}

拓扑排序

用于处理有向无环图(DAG)的排序问题,如:比赛排名、任务调度等

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
vector<int> topoSort(vector<vector<int> > &G)//G为邻接表
{
vector<int> topo;
vector<int> in(G.size(), 0); //记录每个节点的入度
stack<int> s; // 也可以使用priority_queue
for(int u=0; u<G.size(); u++)for(auto &v: G[u])in[v]++;
for(int u=0; u<G.size(); u++)if(in[u]==0)s.push(u);
while(!s.empty())
{
int u=s.top(); s.pop();
topo.push_back(u);
for(auto &v: G[u])
{
in[v]--;
if(in[v]==0)s.push(v);
}
}
return topo;
}

最短路

Dijkstra算法

用于求解单源最短路距离问题(非负权图),时间复杂度 $O(n\log n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#define ll long long 
#define pii pair<int, int> // (dist, vertex) 方便使用优先队列排序
const ll INF=LLONG_MAX;

vector<ll> dijkstra(const vector<vector<pii> > &G, int s)
{
ll n=G.size();
vector<ll> dist(n, INF);
priority_queue<pii, vector<pii>, greater<pii> > pq;
dist[s]=0;
pq.push({0, s});
while(!pq.empty())
{
auto [d, u]=pq.top(); pq.pop();
if(d>dist[u]) continue;
for(auto [v, w]: G[u])if(dist[u]+w<dist[v])
{
dist[v]=dist[u]+w;
pq.push({dist[v], v});
}
}
return dist;
}

Bellman-Ford算法

适用于可能包含负权边的有向图或无向图的最短路问题,时间复杂度 $O(nm)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#define ll long long
const ll INF=LLONG_MAX;

struct Edge
{
int u, v, w;
Edge(int _u, int _v, int _w): u(_u), v(_v), w(_w) {}
};

/*
* @param s: 源点
* @param end: 终点
* @param E: 边集
* @param n: 点数
* @param dist: 最短路距离
* @param path: 最短路径
* @return: 是否存在负环
*/
bool bellmanFord(ll s, ll end, const vector<Edge> &E, ll n , vector<ll> &dist, vector<ll> &path)
{
vector<ll> pre(n, -1);
dist.assign(n, INF);
dist[s]=0;
for(ll i=0; i<n-1; i++)
{
bool flag=false;
for(const auto &e: E)
{
if(dist[e.u]!=INF && dist[e.u]+e.w<dist[e.v])
{
dist[e.v]=dist[e.u]+e.w;
pre[e.v]=e.u;
flag=true;
}
}
if(!flag) break;
}
// 判断是否存在负环
for(const auto &e: E)
if(dist[e.u]!=INF && dist[e.u]+e.w<dist[e.v]) return true;
if(dist[end]==INF) return false; // 不存在最短路

// 求最短路径
path.clear();
for(ll i=end; i!=-1; i=pre[i]) path.push_back(i);
reverse(path.begin(), path.end());

return false;
}

SPFA算法

该算法是Bellman-Ford算法的优化版本,时间复杂度 $O(km)$,其中 $k$ 是常数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#define ll long long 
const ll INF=LLONG_MAX;

struct Edge
{
int to, w;
Edge(int _to, int _w): to(_to), w(_w) {}
};

bool spfa(ll s, ll end, ll n, const vector<vector<Edge> > &G, vector<ll> &dist, vector<ll> &path)
{
vector<ll> pre(n, -1);
vector<int> cnt(n, 0);
vector<bool> inq(n, false);
dist.assign(n, INF);
dist[s]=0;
queue<int> q;
q.push(s);
inq[s]=true, cnt[s]=1;
while(!q.empty())
{
int u=q.front(); q.pop();
inq[u]=false;
for(const auto [v, w]: G[u])
{
if(dist[u]!=INF && dist[u]+w<dist[v])
{
dist[v]=dist[u]+w;
pre[v]=u;
if(!inq[v])
{
q.push(v);
inq[v]=true, cnt[v]++;
if(cnt[v]>n) return true; // 存在负环
}
}
}
}
if(dist[end]==INF) return false; // 不存在最短路

// 求最短路径
path.clear();
for(ll i=end; i!=-1; i=pre[i]) path.push_back(i);
reverse(path.begin(), path.end());

return false;
}

Floyd-Warshall 算法

用于求解所有顶点对之间的最短路径,适用于有向图或无向图,可以处理负权边(但不能有负权回路)。 时间复杂度 $O(n^3)$ , 仅适用于 $n \le 500$ 的情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
const int INF=0x3f3f3f3f;
const int MAXN=505;

int n, dis[MAXN][MAXN];

void floyd()
{
for(int k=1; k<=n; k++)
{
for(int i=1; i<=n; i++)
{
for(int j=1; j<=n; j++) if(dis[i][k]!=INF && dis[k][j]!=INF)
dis[i][j]=min(dis[i][j], dis[i][k]+dis[k][j]);
}
}
}
// 初始化
void init()
{
for(int i=1; i<=n; i++)
{
for(int j=1; j<=n; j++)
{
if(i==j) dis[i][j]=0;
else dis[i][j]=INF;
}
}
}
// 检测负权回路
bool hasNegativeCycle()
{
for(int k=1; k<=n; k++)
if(dis[k][k]<0) return true;
return false;
}

如果需要重建最短路径,可以额外维护一个next矩阵,记录从ij的最短路径上i的后继节点。

最小生成树(MST)

Kruskal算法

时间复杂度 $O(m\log m)$,常用于稀疏图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
struct Edge
{
int u, v, w;
bool operator<(const Edge &e) const {return w<e.w;}
};

int fa[N];
int find(int x) {return fa[x]==x? x: fa[x]=find(fa[x]);}

ll kruskal(vector<Edge> &edges, int n)
{
sort(edges.begin(), edges.end());
for(int i=1; i<=n; i++) fa[i]=i;
ll sum=0;
int cnt=0;
for(auto [u, v, w]: edges)
{
int fu=find(u), fv=find(v);
if(fu!=fv)
{
fa[fu]=fv;
sum+=w;
cnt++;
if(cnt==n-1) break;
}
}
return cnt==n-1? sum: -1; // -1表示图不连通
}

Prim算法

时间复杂度 $O(m\log n)$,常用于稠密图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
ll prim(vector<vector<pii>> &g, int n)
{
vector<bool> vis(n+1, false);
priority_queue<pii, vector<pii>, greater<pii>> pq;
pq.push({0, 1}); // {权重, 节点}
ll sum=0;
int cnt=0;
while(!pq.empty())
{
auto [w, u]=pq.top(); pq.pop();
if(vis[u]) continue;
vis[u]=true;
sum+=w;
cnt++;
for(auto [v, wt]: g[u])
if(!vis[v]) pq.push({wt, v});
}
return cnt==n? sum: -1;
}

强连通分量(SCC)

Kosaraju算法

时间复杂度 $O(V+E)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
const int MAXN=100005;

vector<int> G[MAXN], G_[MAXN], comp_nodes[MAXN];
vector<int> order, comp(MAXN);
bool vis[MAXN];
int scc;

void dfs1(int u)
{
vis[u]=true;
for(auto v: G[u]) if(!vis[v]) dfs1(v);
order.push_back(u);
}

void dfs2(int u, int c)
{
comp[u]=c;
comp_nodes[c].push_back(u);
vis[u]=true;
for(auto v: G_[u]) if(!vis[v]) dfs2(v, c);
}

void kosaraju(int n)
{
order.clear();
scc=0;
memset(vis, false, sizeof(vis));
memset(comp, -1, sizeof(comp));
for(int i=1; i<=n; i++) if(!vis[i]) dfs1(i);
memset(vis, false, sizeof(vis));
for(int i=(int)order.size()-1; i>=0; i--) if(!vis[order[i]])
dfs2(order[i], ++scc);
}

vector<int> DAG[MAXN];
void buildDAG(int n)
{
for(int u=1; u<=n; u++)
{
for(auto v: G[u])if(comp[u]!=comp[v])
DAG[comp[u]].push_back(comp[v]);
}
// 去重边(可选)
for(int i=1; i<=scc; i++)
{
sort(DAG[i].begin(), DAG[i].end());
DAG[i].erase(unique(DAG[i].begin(), DAG[i].end()), DAG[i].end());
}
}

Tarjan算法

不仅可用于求强连通分量,还可用于求割点、桥等图论问题。时间复杂度 $O(V+E)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
const int N=100005;
vector<int> G[N];
int dfn[N], low[N], comp[N], inSt[N];
stack<int> st;
int dfsti=0, scc=0;

/*
* @param u 当前节点
*/
void tarjan(int u)
{
dfn[u]=low[u]=++dfsti;
st.push(u); inSt[u]=1;
for(int v: G[u])
{
if(!dfn[v])
{
tarjan(v);
low[u]=min(low[u], low[v]);
}
else if(inSt[v]) low[u]=min(low[u], dfn[v]);
}
if(dfn[u]==low[u])
{
scc++;
while(1)
{
int x=st.top(); st.pop();
inSt[x]=0, comp[x]=scc;
if(x==u) break;
}
}
}
/*
使用方法:
for(int i=1; i<=n; i++) if(!dfn[i]) tarjan(i); // 遍历所有节点求解强连通分量
*/

欧拉回路/欧拉路径

欧拉回路:经过图中每条边恰好一次的回路。欧拉路径:经过图中每条边恰好一次的路径。

  • 无向图欧拉回路:连通且所有点度数为偶数
  • 无向图欧拉路径:连通且恰有0或2个奇度数点
  • 有向图欧拉回路:强连通且所有点入度=出度
  • 有向图欧拉路径:弱连通且恰有一个点出度-入度=1,一个点入度-出度=1,其余点入度=出度

Hierholzer算法(无向图)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
vector<pii> g[N]; // {邻接点, 边编号}
bool vis[M]; // 边是否被访问
vector<int> path;

void dfs(int u)
{
for(auto [v, id]: g[u])
{
if(vis[id]) continue;
vis[id]=true;
dfs(v);
}
path.push_back(u);
}

// 从起点st开始找欧拉路径
void euler(int st)
{
dfs(st);
reverse(path.begin(), path.end());
}

Hierholzer算法(有向图)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
vector<int> g[N];
int cur[N]; // 当前弧优化
vector<int> path;

void dfs(int u)
{
for(int &i=cur[u]; i<g[u].size(); i++)
{
int v=g[u][i];
dfs(v);
}
path.push_back(u);
}

void euler(int st)
{
dfs(st);
reverse(path.begin(), path.end());
}

2-SAT

给定 $n$ 个布尔变量和 $m$ 个形如”$x_i$ 为真或 $x_j$ 为假”的约束,判断是否有解。

对每个变量 $x_i$ 建立两个点 $i$(真) 和 $i’$(假)。约束”$x_i$ 或 $x_j$”转化为:$\neg x_i\Rightarrow x_j$ 和 $\neg x_j\Rightarrow x_i$。在图上跑强连通分量,若 $x_i$ 和 $x_i’$ 在同一SCC则无解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
const int N=2e5+5;
vector<int> g[N];
int n; // n个变量

// i表示xi为真,i+n表示xi为假
void addClause(int a, bool va, int b, bool vb)
{
// (a,va) OR (b,vb)
// !a => b, !b => a
g[a+(va?0:n)].push_back(b+(vb?n:0));
g[b+(vb?0:n)].push_back(a+(va?n:0));
}

// Tarjan求SCC
int dfn[N], low[N], comp[N], inSt[N];
stack<int> st;
int dfscnt=0, scc=0;

void tarjan(int u)
{
dfn[u]=low[u]=++dfscnt;
st.push(u); inSt[u]=1;
for(int v: g[u])
{
if(!dfn[v]) tarjan(v), low[u]=min(low[u], low[v]);
else if(inSt[v]) low[u]=min(low[u], dfn[v]);
}
if(dfn[u]==low[u])
{
scc++;
while(1)
{
int x=st.top(); st.pop();
inSt[x]=0; comp[x]=scc;
if(x==u) break;
}
}
}

bool solve()
{
for(int i=0; i<2*n; i++)
if(!dfn[i]) tarjan(i);

for(int i=0; i<n; i++)
if(comp[i]==comp[i+n]) return false;
return true;
}

/*
使用例子:
addClause(0, true, 1, false); // x0 OR !x1
addClause(1, true, 2, true); // x1 OR x2
if(solve()) cout<<"YES\n";
else cout<<"NO\n";

// 输出方案:comp[i] > comp[i+n] 则 xi=true
*/

差分约束

求解一组形如 $x_i-x_j\le c_k$ 的不等式组,判断是否有解。

思路:将不等式转化为图,$x_i-x_j\le c$ 建边 $j\to i$ 权值 $c$。求最短路,若存在负环则无解,否则 $d[i]$ 即为一组解。

求最大值(用最短路)

$x_i-x_j\le c$ 建边 $j\to i$ 权值 $c$,求最短路。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
const int N=5005;
const ll INF=LLONG_MAX;
vector<pii> g[N]; // {v, w}
int n, m;

bool spfa(int st, vector<ll> &d)
{
vector<int> cnt(n+1, 0), inq(n+1, 0);
d.assign(n+1, INF);
queue<int> q;

// 超级源点,连接所有点
for(int i=1; i<=n; i++)
{
q.push(i);
d[i]=0;
inq[i]=1;
}

while(!q.empty())
{
int u=q.front(); q.pop();
inq[u]=0;
for(auto [v, w]: g[u])
{
if(d[v]>d[u]+w)
{
d[v]=d[u]+w;
if(!inq[v])
{
q.push(v);
inq[v]=1;
if(++cnt[v]>n) return false; // 负环
}
}
}
}
return true;
}

/*
使用:
// xi - xj <= c
g[j].push_back({i, c});

vector<ll> d;
if(spfa(1, d)) cout<<"有解\n";
else cout<<"无解\n";
*/

求最小值(用最长路)

$x_i-x_j\ge c$ 转化为 $x_j-x_i\le -c$,建边 $i\to j$ 权值 $-c$,求最长路。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
bool spfa_max(int st, vector<ll> &d)
{
vector<int> cnt(n+1, 0), inq(n+1, 0);
d.assign(n+1, -INF);
queue<int> q;
for(int i=1; i<=n; i++)
{
q.push(i);
d[i]=0;
inq[i]=1;
}

while(!q.empty())
{
int u=q.front(); q.pop();
inq[u]=0;
for(auto [v, w]: g[u])
{
if(d[v]<d[u]+w)
{
d[v]=d[u]+w;
if(!inq[v])
{
q.push(v);
inq[v]=1;
if(++cnt[v]>n) return false;
}
}
}
}
return true;
}

有向图缩点

将有向图的每个强连通分量缩成一个点,得到DAG(有向无环图)。

用Tarjan或Kosaraju求强连通分量,然后重新建图。缩点后常用于DP。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
const int N=1e4+5;
vector<int> g[N], dag[N];
int n, m, scc_cnt;
int dfn[N], low[N], comp[N], inSt[N];
stack<int> st;
int dfscnt=0;

void tarjan(int u)
{
dfn[u]=low[u]=++dfscnt;
st.push(u); inSt[u]=1;
for(int v: g[u])
{
if(!dfn[v]) tarjan(v), low[u]=min(low[u], low[v]);
else if(inSt[v]) low[u]=min(low[u], dfn[v]);
}
if(dfn[u]==low[u])
{
scc_cnt++;
while(1)
{
int x=st.top(); st.pop();
inSt[x]=0; comp[x]=scc_cnt;
if(x==u) break;
}
}
}

void rebuild()
{
for(int u=1; u<=n; u++)
for(int v: g[u])
if(comp[u]!=comp[v])
dag[comp[u]].push_back(comp[v]);

// 去重
for(int i=1; i<=scc_cnt; i++)
{
sort(dag[i].begin(), dag[i].end());
dag[i].erase(unique(dag[i].begin(), dag[i].end()), dag[i].end());
}
}

/*
使用:
for(int i=1; i<=n; i++)
if(!dfn[i]) tarjan(i);
rebuild();

// 缩点后DP示例:求DAG最长路
int dp[N];
int dfs_dag(int u)
{
if(dp[u]) return dp[u];
dp[u]=1;
for(int v: dag[u])
dp[u]=max(dp[u], dfs_dag(v)+1);
return dp[u];
}
*/

二分图

Kuhn-Munkres算法

匈牙利算法用于求解二分图的最大匹配问题,时间复杂度 $O(nm)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
const int MAXN=505;
int n,m; // 左右节点数
vector<int> G[MAXN];
vector<int> match; // match[v] 为与右部 v 匹配的左部点
vector<bool> vis;

bool dfs(int u)
{
for(int v: G[u])if(!vis[v])
{
vis[v]=true;
if(match[v]==-1 || dfs(match[v]))
{
match[v]=u;
return true;
}
}
return false;
}

int kuhn(int n)
{
int res=0;
match.assign(m+1, -1);
for(int i=1; i<=n; i++)
{
vis.assign(m+1, false);
if(dfs(i)) res++;
}
return res;
}

Hopcroft-Karp算法

用于求解大数据集下的二分图最大匹配问题,时间复杂度 $O(m\sqrt{n})$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
const int MAXN=505;
vector<int> G[MAXN];
int matchL[MAXN], matchR[MAXN], dist[MAXN];
int n,m; // 左右节点数

bool bfs()
{
queue<int> q;
for(int u=1; u<=n; u++)
{
if(matchL[u]==-1)
{
dist[u]=0;
q.push(u);
}
else dist[u]=-1;
}
bool f=false;
while(!q.empty())
{
int u=q.front(); q.pop();
for(int v: G[u])
{
if(matchR[v]==-1) f=true;
else if(dist[matchR[v]]==-1)
{
dist[matchR[v]]=dist[u]+1;
q.push(matchR[v]);
}
}
}
return f;
}

bool dfs(int u)
{
for(int v: G[u])
{
if(matchR[v]==-1 || (dist[matchR[v]]==dist[u]+1 && dfs(matchR[v])))
{
matchL[u]=v, matchR[v]=u;
return true;
}
}
dist[u]=-1;
return false;
}

int hopcroft_karp()
{
fill(matchL, matchL+n+1, -1);
fill(matchR, matchR+m+1, -1);
int res=0;
while(bfs())
{
for(int u=1; u<=n; u++) if(matchL[u]==-1) res+=dfs(u);
}
return res;
}

二分图判定

用DFS染色法判断图是否为二分图,时间复杂度 $O(V+E)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
int color[N];
vector<int> g[N];

bool dfs(int u, int c)
{
color[u]=c;
for(int v: g[u])
{
if(color[v]==c) return false;
if(color[v]==0 && !dfs(v, 3-c)) return false;
}
return true;
}

bool isBipartite(int n)
{
memset(color, 0, sizeof(color));
for(int i=1; i<=n; i++)
if(color[i]==0 && !dfs(i, 1))
return false;
return true;
}

网络流

Dinic算法

单源单汇最大流,时间复杂度 $O(V^2E)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
struct Edge
{
int to, cap, flow;
};

vector<Edge> edges;
vector<int> g[N];
int d[N], cur[N];
int n, m, s, t; // s:源点, t:汇点

void addEdge(int u, int v, int cap)
{
edges.push_back({v, cap, 0});
edges.push_back({u, 0, 0});
g[u].push_back(edges.size()-2);
g[v].push_back(edges.size()-1);
}

bool bfs()
{
memset(d, 0, sizeof(d));
queue<int> q;
q.push(s);
d[s]=1;
while(!q.empty())
{
int u=q.front(); q.pop();
for(int i: g[u])
{
Edge &e=edges[i];
if(!d[e.to] && e.cap>e.flow)
{
d[e.to]=d[u]+1;
q.push(e.to);
}
}
}
return d[t];
}

int dfs(int u, int a)
{
if(u==t || a==0) return a;
int flow=0, f;
for(int &i=cur[u]; i<g[u].size(); i++)
{
Edge &e=edges[g[u][i]];
if(d[e.to]==d[u]+1 && (f=dfs(e.to, min(a, e.cap-e.flow)))>0)
{
e.flow+=f;
edges[g[u][i]^1].flow-=f;
flow+=f;
a-=f;
if(a==0) break;
}
}
return flow;
}

int maxFlow()
{
int flow=0;
while(bfs())
{
memset(cur, 0, sizeof(cur));
flow+=dfs(s, INT_MAX);
}
return flow;
}

建模技巧

  • 多源多汇:建超级源 $S$ 和超级汇 $T$,$S \to$ 所有源,所有汇 $\to T$,容量 $\infty$
  • 点容量:拆点,$u$ 拆成 $u{in}$ 和 $u{out}$,连边容量为点容量
  • 无向边:两个方向各加一条边
  • 流量下界 $[L,R]$:边容量改为 $R-L$,$S\to v$ 容量 $L$,$u\to T$ 容量 $L$
  • 最小割:最大流 = 最小割,最后BFS后从源点可达的点为 $S$ 集合

最小费用最大流

在最大流的基础上,使得总费用最小。常用于带费用的匹配、运输等问题。

每次找费用最小的增广路(用SPFA),直到无法增广。时间复杂度 $O(nmf)$,其中 $f$ 为流量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
const int N=5005, M=50005;
const ll INF=1e18;

struct Edge
{
int to, nxt;
ll cap, cost;
}e[M<<1];

int head[N], cnt=1;
int n, m, s, t;
ll maxflow, mincost;

void addEdge(int u, int v, ll w, ll c)
{
e[++cnt]={v, head[u], w, c}; head[u]=cnt;
e[++cnt]={u, head[v], 0, -c}; head[v]=cnt;
}

ll dis[N], pre[N], flow[N];
bool inq[N];

bool spfa()
{
fill(dis, dis+n+1, INF);
memset(inq, 0, sizeof(inq));
queue<int> q;
q.push(s);
dis[s]=0; inq[s]=1;
flow[s]=INF;

while(!q.empty())
{
int u=q.front(); q.pop();
inq[u]=0;
for(int i=head[u]; i; i=e[i].nxt)
{
int v=e[i].to;
if(e[i].cap>0 && dis[v]>dis[u]+e[i].cost)
{
dis[v]=dis[u]+e[i].cost;
pre[v]=i;
flow[v]=min(flow[u], e[i].cap);
if(!inq[v]) q.push(v), inq[v]=1;
}
}
}
return dis[t]!=INF;
}

void mcmf()
{
maxflow=mincost=0;
while(spfa())
{
maxflow+=flow[t];
mincost+=flow[t]*dis[t];
for(int i=t; i!=s; i=e[pre[i]^1].to)
{
e[pre[i]].cap-=flow[t];
e[pre[i]^1].cap+=flow[t];
}
}
}

/*
使用:
addEdge(u, v, cap, cost); // u->v 容量cap 费用cost
mcmf();
cout<<maxflow<<" "<<mincost<<endl;
*/

割点和桥(Tarjan)

时间复杂度 $O(V+E)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
vector<int> g[N];
int dfn[N], low[N], dfscnt=0;
bool iscut[N]; // 是否为割点
vector<pii> bridges; // 桥

void tarjan(int u, int fa)
{
dfn[u]=low[u]=++dfscnt;
int child=0;
for(int v: g[u])
{
if(!dfn[v])
{
child++;
tarjan(v, u);
low[u]=min(low[u], low[v]);

// 判断割点
if(low[v]>=dfn[u] && fa!=-1) iscut[u]=true;

// 判断桥
if(low[v]>dfn[u]) bridges.push_back({u, v});
}
else if(v!=fa) low[u]=min(low[u], dfn[v]);
}

// 根节点的割点判断
if(fa==-1 && child>1) iscut[u]=true;
}

/*
使用:
memset(dfn, 0, sizeof(dfn));
memset(iscut, false, sizeof(iscut));
bridges.clear();
for(int i=1; i<=n; i++)
if(!dfn[i]) tarjan(i, -1);
*/

树的直径

树的直径是指树中任意两点间的最长路径。这里使用树形DP求解 $O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
const int MAXN=100005;
vector<int> tree[MAXN];
int dia=0;
int dfs(int u, int fa)
{
int max1=0, max2=0;
for(int v: tree[u])
{
if(v==fa) continue;
int d=dfs(v, u)+1;
if(d>max1) max2=max1, max1=d;
else if(d>max2) max2=d;
}
dia=max(dia, max1+max2);
return max1;
}

int treeDia(int n)
{
dia=0;
dfs(1, 0);
return dia;
}

倍增LCA

通过倍增预处理节点的 $2^k$ 个祖先,使得可以在时间复杂度 $O(\log n)$ 找到任意两个节点的最近公共祖先。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
const int MAXN=100005;
const int LOG=17; // log2(n)的向上取整

vector<int> tree[MAXN];
int dep[MAXN], fa[MAXN][LOG];

// BFS初始化深度和直接父节点
void bfs(int root, int n)
{
queue<int> q;
q.push(root);
dep[root]=1, fa[root][0]=-1;
while(!q.empty())
{
int u=q.front(); q.pop();
for(int v: tree[u])
{
if(v==fa[u][0]) continue;
dep[v]=dep[u]+1;
fa[v][0]=u;
q.push(v);
}
}
}


// 预处理倍增表
void init(int n)
{
bfs(1, n); // 根节点为1
for(int j=1; j<LOG; j++)
{
for(int i=1; i<=n; i++)
{
if(fa[i][j-1]==-1) fa[i][j]=-1;
else fa[i][j]=fa[fa[i][j-1]][j-1];
}
}
}

// 将节点u向上移动k步
int liftup(int u, int k)
{
for(int j=0; j<LOG; j++)
{
if(k&(1<<j))
{
u=fa[u][j];
if(u==-1) break;
}
}
return u;
}

// 找到u和v的最近公共祖先
int lca(int u, int v)
{
if(dep[u]<dep[v]) swap(u, v);
u=liftup(u, dep[u]-dep[v]);
if(u==v) return u;
for(int j=LOG-1; j>=0; j--)
if(fa[u][j]!=fa[v][j])
u=fa[u][j], v=fa[v][j];
return fa[u][0];
}

// 求两点间的距离
int dis(int u, int v)
{
int w=lca(u, v);
return dep[u]+dep[v]-2*dep[w];
}

// 使用之前记得init()

Tarjan算法(离线)

Tarjan算法(离线)可以高效地解决LCA的问题 $O(n+\alpha(n))$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
struct DSU
{
int n;
vector<int> fa, sz;

DSU(int _n)
{
n=_n;
fa.resize(n);
sz.resize(n, 1);
for(int i=0; i<n; i++) fa[i]=i;
}

int find(int x)
{
if(fa[x]!=x) fa[x]=find(fa[x]);
return fa[x];
}

void unite(int x, int y)
{
int ra=find(x), rb=find(y);
if(ra!=rb)
{
if(sz[ra]<sz[rb]) swap(ra, rb);
fa[rb]=ra;
sz[ra]+=sz[rb];
}
}
};

vector<int> tree[MAXN];
vector<pii> query[MAXN]; // query[u]=(v,query_id)
vector<int> ancestor, ans;
vector<bool> vis;

void tarjan(int u, int fa, DSU& dsu)
{
ancestor[u]=u;
for(int v: tree[u])
{
if(v==fa) continue;
tarjan(v, u, dsu);
dsu.unite(u, v);
ancestor[dsu.find(u)]=u;
}
vis[u]=true;
for(auto &[v, idx]: query[u]) if(vis[v])
ans[idx]=ancestor[dsu.find(v)];
}
/*
使用前记得初始化
DSU dsu(n+1);
ancestor.resize(n+1);
vis.resize(n+1,false);
ans.resize(t); // t个查询
//然后在query中加入查询
for(int i=0; i<m; i++)
query[u].push_back({v, i});
query[v].push_back({u, i});
//然后tarjan(1,0,dsu);
*/

字符串

STL的容器本身提供了一些字符串的算法:

  • unordered_map的原理是哈希表,可以用来快速查找字符串。
  • stringsubstr((int/size_t)pos, len) 函数可以用来截取子串。
  • stringfind()函数可以用来查找子串。找到,返回子串在原串中的起始位置(下标size_t), 否则返回string::npos
  • stringreplace(pos, len, str)函数可以用来替换子串,从pos开始长度为len的子串被替换为str
  • reverse()函数可以用来反转字符串。

KMP

KMP用于快速匹配字符串,时间复杂度 $O(n+m)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
vector<int> KMP(const string& s, const string& p)
{
// next array
int n=s.size(), m=p.size();
vector<int> nxt(m);
for(int i=1, j=0; i<m; i++)
{
while(j>0 && p[i]!=p[j]) j=nxt[j-1];
if(p[i]==p[j]) j++;
nxt[i]=j;
}
// search
vector<int> res;
for(int i=0, j=0; i<n; i++)
{
while(j>0 && s[i]!=p[j]) j=nxt[j-1];
if(s[i]==p[j]) j++;
if(j==m)
{
res.push_back(i-m+1);
j=nxt[j-1];
}
}
return res;
}
/*
How to ues:
string s="abcabcaaabcaabccbabca", p="abc";
auto pos=KMP(s, p);
for(auto i: pos) cout<<i<<' '; // output initial the match positions
*/

Z-Algorithm

Z-Algorithm用于快速找到字符串的所有子串与其自身匹配的位置,时间复杂度 $O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
vector<int> Z_Algo(const string& s)
{
int n=s.size();
vector<int> z(n, 0);
int l=0, r=0;
for(int i=1; i<n; i++)
{
if(i<=r) z[i]=min(r-i+1, z[i-l]);
while(i+z[i]<n && s[z[i]]==s[i+z[i]]) z[i]++;
if(i+z[i]-1>r) l=i, r=i+z[i]-1;
}
return z;
}

vector<int> find(const string& s, const string& p)
{
string t=p+'#'+s;
vector<int> z=Z_Algo(t);
vector<int> res;
int m=p.size();
for(int i=m+1; i<(int)t.size(); i++)
if(z[i]>=m) res.push_back(i-m-1);
return res;
}
/*
How to ues:
string s="abcabcaaabcaabccbabca", p="abc";
auto pos=find(s, p); // output initial the match positions
}

Trie树(字典树)

Trie树用于快速插入、查找、删除、前缀匹配单词,时间复杂度 $O(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
struct Trie
{
Trie* ch[26];
bool isEnd;
int cnt;
Trie(): isEnd(false), cnt(0)
{
memset(ch, 0, sizeof(ch));
}
};

void insert(Trie* root, const string& s)
{
Trie* p=root;
for(char c: s)
{
int idx=c-'a';
if(!p->ch[idx]) p->ch[idx]=new Trie();
p=p->ch[idx];
p->cnt++;
}
p->isEnd=true;
}

bool find(Trie* root, const string& s)
{
Trie* p=root;
for(char c: s)
{
int u=c-'a';
if(!p->ch[u]) return false;
p=p->ch[u];
}
return p->isEnd;
}

int prefixCnt(Trie* root, const string& s)
{
Trie* p=root;
for(char c: s)
{
int u=c-'a';
if(!p->ch[u]) return 0;
p=p->ch[u];
}
return p->cnt;
}

// 先检查find(s),再erase(s)
bool erase(Trie* root, const string& s)
{
if(!find(root, s)) return false;

Trie* p=root;
vector<Trie*> path;
for(char c: s)
{
int u=c-'a';
p=p->ch[u];
path.push_back(p);
}
p->isEnd=false;
for(int i=(int)s.size()-1; i>=0; i--) path[i]->cnt--;
return true;
}
/*
How to ues:
Trie* root=new Trie();
insert(root, "abc");
insert(root, "ab");
insert(root, "abcd");
cout<<find(root, "abc")<<endl; // output 1
cout<<prefixCnt(root, "ab")<<endl; // output 3
erase(root, "abcd");
cout<<find(root, "abcd")<<endl; // output 0
*/

Manacher算法

求字符串的最长回文子串,时间复杂度 $O(n)$。

通过在字符间插入特殊字符(如#)统一处理奇偶长度回文串。维护回文中心和右边界,利用对称性减少重复计算。$p[i]$ 表示以 $i$ 为中心的最长回文半径。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
string manacher(string s)
{
string t="#";
for(char c: s) t+=c, t+='#';
int n=t.size();
vector<int> p(n);
int mx=0, id=0; // mx为最右边界,id为对应中心

for(int i=0; i<n; i++)
{
p[i]=(i<mx)? min(p[2*id-i], mx-i): 1;
while(i-p[i]>=0 && i+p[i]<n && t[i-p[i]]==t[i+p[i]]) p[i]++;
if(i+p[i]>mx) mx=i+p[i], id=i;
}

// 找最长回文
int maxLen=0, center=0;
for(int i=0; i<n; i++)
{
if(p[i]-1>maxLen)
{
maxLen=p[i]-1;
center=i;
}
}

int start=(center-maxLen)/2;
return s.substr(start, maxLen);
}

// 返回所有位置的最长回文半径
vector<int> manacher_array(string s)
{
string t="#";
for(char c: s) t+=c, t+='#';
int n=t.size();
vector<int> p(n);
int mx=0, id=0;

for(int i=0; i<n; i++)
{
p[i]=(i<mx)? min(p[2*id-i], mx-i): 1;
while(i-p[i]>=0 && i+p[i]<n && t[i-p[i]]==t[i+p[i]]) p[i]++;
if(i+p[i]>mx) mx=i+p[i], id=i;
}

return p; // p[i]-1为原串中以i/2为中心的最长回文长度
}

AC自动机

多模式串匹配,在文本串中查找多个模式串的所有出现位置。时间复杂度 $O(n+m+z)$,$n$ 为文本长度,$m$ 为模式串总长度,$z$ 为匹配次数。

在Trie树基础上加上失配指针(类似KMP的next数组),失配时跳转到最长后缀对应节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
const int MAXN=1e6+5;

struct AC
{
int tr[MAXN][26], fail[MAXN], cnt[MAXN];
int idx;

void init()
{
idx=0;
memset(tr[0], 0, sizeof(tr[0]));
}

void insert(const string &s)
{
int p=0;
for(char c: s)
{
int u=c-'a';
if(!tr[p][u])
{
tr[p][u]=++idx;
memset(tr[idx], 0, sizeof(tr[idx]));
cnt[idx]=0;
}
p=tr[p][u];
}
cnt[p]++;
}

void build()
{
queue<int> q;
for(int i=0; i<26; i++)
if(tr[0][i]) q.push(tr[0][i]);

while(!q.empty())
{
int u=q.front(); q.pop();
for(int i=0; i<26; i++)
{
if(tr[u][i])
{
fail[tr[u][i]]=tr[fail[u]][i];
q.push(tr[u][i]);
}
else tr[u][i]=tr[fail[u]][i];
}
}
}

int query(const string &s)
{
int p=0, res=0;
for(char c: s)
{
p=tr[p][c-'a'];
for(int j=p; j && cnt[j]!=-1; j=fail[j])
{
res+=cnt[j];
cnt[j]=-1; // 标记已访问,避免重复计数
}
}
return res;
}
};

/*
使用:
AC ac;
ac.init();
ac.insert("she");
ac.insert("he");
ac.insert("her");
ac.build();
cout<<ac.query("sherhershe")<<endl; // 输出匹配次数
*/

后缀数组

对字符串的所有后缀进行排序,支持快速查询最长公共前缀(LCP)、子串查找等。时间复杂度 $O(n\log n)$。

倍增法构造后缀数组。$sa[i]$ 表示排名第 $i$ 的后缀的起始位置,$rk[i]$ 表示起始位置为 $i$ 的后缀的排名,$height[i]$ 表示 $sa[i]$ 和 $sa[i-1]$ 的最长公共前缀长度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
const int N=1e5+5;

struct SA
{
int n, m;
int sa[N], rk[N], oldrk[N], id[N], cnt[N];
int height[N]; // height[i] = lcp(sa[i], sa[i-1])

void build(const string &s)
{
n=s.size();
m=max(n, 300);

// 倍增构造SA
for(int i=1; i<=n; i++) cnt[rk[i]=s[i-1]]++;
for(int i=1; i<=m; i++) cnt[i]+=cnt[i-1];
for(int i=n; i>=1; i--) sa[cnt[rk[i]]--]=i;

for(int w=1; w<n; w<<=1)
{
int p=0;
for(int i=n; i>n-w; i--) id[++p]=i;
for(int i=1; i<=n; i++) if(sa[i]>w) id[++p]=sa[i]-w;

memset(cnt, 0, sizeof(cnt));
for(int i=1; i<=n; i++) cnt[rk[id[i]]]++;
for(int i=1; i<=m; i++) cnt[i]+=cnt[i-1];
for(int i=n; i>=1; i--) sa[cnt[rk[id[i]]]--]=id[i];

memcpy(oldrk, rk, sizeof(rk));
p=0;
for(int i=1; i<=n; i++)
{
if(oldrk[sa[i]]==oldrk[sa[i-1]] &&
oldrk[sa[i]+w]==oldrk[sa[i-1]+w]) rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
if(p==n) break;
m=p;
}

// 构造height数组
int k=0;
for(int i=1; i<=n; i++)
{
if(rk[i]==1) continue;
if(k) k--;
int j=sa[rk[i]-1];
while(i+k<=n && j+k<=n && s[i+k-1]==s[j+k-1]) k++;
height[rk[i]]=k;
}
}

// 求lcp(后缀i, 后缀j)
int lcp(int i, int j)
{
if(i==j) return n-i+1;
if(rk[i]>rk[j]) swap(i, j);
int ans=INT_MAX;
for(int k=rk[i]+1; k<=rk[j]; k++)
ans=min(ans, height[k]);
return ans;
}
};

/*
使用:
SA sa;
string s="banana";
sa.build(s);

// sa.sa[i] 为排名第i的后缀起始位置
for(int i=1; i<=s.size(); i++)
cout<<s.substr(sa.sa[i]-1)<<endl;

// 求最长公共前缀
cout<<sa.lcp(1, 4)<<endl; // banana 和 ana 的LCP
*/

应用:最长重复子串

求字符串中最长的重复子串。

思路:最长重复子串即为任意两个后缀的最长公共前缀的最大值,即 $\max(height[i])$。

1
2
3
4
5
6
7
8
9
int cnt(const string &s)
{
SA sa;
sa.build(s);
int ans=0;
for(int i=2; i<=s.size(); i++)
ans=max(ans, sa.height[i]);
return ans;
}

应用:不同子串个数

求字符串中不同子串的个数。

思路:总子串数为 $\frac{n(n+1)}{2}$,减去重复的(即相邻后缀的LCP)。

1
2
3
4
5
6
7
8
9
10
ll cnt(const string &s)
{
SA sa;
sa.build(s);
int n=s.size();
ll ans=(ll)n*(n+1)/2;
for(int i=2; i<=n; i++)
ans-=sa.height[i];
return ans;
}

动态规划

背包问题

0/1背包

有 $n$ 种物品和一个容量为 $V$ 的背包,第 $i$ 种物品的体积为 $w_i$,价值为 $v_i$,每种物品只有一件,求解将哪些物品装入背包使总价值最大。

思路:$dp[i][j]$ 表示前 $i$ 件物品放入容量为 $j$ 的背包的最大价值。对于第 $i$ 件物品,要么不选,要么选,取最大值。空间优化后用一维数组,逆序枚举容量保证每件物品只用一次。

1
2
3
4
5
6
7
int n, V; // n个物品, 背包容量V
int w[N], v[N]; // 重量, 价值
int dp[N]; // dp[j]表示容量为j时的最大价值

for(int i=1; i<=n; i++)
for(int j=V; j>=w[i]; j--)
dp[j]=max(dp[j], dp[j-w[i]]+v[i]);

完全背包

有 $n$ 种物品和一个容量为 $V$ 的背包,第 $i$ 种物品的体积为 $w_i$,价值为 $v_i$,每种物品有无限件,求解将哪些物品装入背包使总价值最大。

思路:与0/1背包的区别是每种物品可以选无限次,正序枚举容量即可。

1
2
3
for(int i=1; i<=n; i++)
for(int j=w[i]; j<=V; j++)
dp[j]=max(dp[j], dp[j-w[i]]+v[i]);

多重背包

有 $n$ 种物品和一个容量为 $V$ 的背包,第 $i$ 种物品的体积为 $w_i$,价值为 $v_i$,数量为 $s_i$,求解将哪些物品装入背包使总价值最大。

思路:朴素做法是对每种物品枚举选几个,时间复杂度 $O(V\sum s_i)$。二进制优化:将 $s_i$ 个物品拆分成 $1,2,4,\ldots,2^k$ 和剩余部分,转化为0/1背包,时间复杂度 $O(V\sum\log s_i)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int s[N]; // 数量

// 二进制优化
for(int i=1; i<=n; i++)
{
int num=s[i];
for(int k=1; k<=num; k<<=1)
{
num-=k;
for(int j=V; j>=w[i]*k; j--)
dp[j]=max(dp[j], dp[j-w[i]*k]+v[i]*k);
}
if(num>0)
for(int j=V; j>=w[i]*num; j--)
dp[j]=max(dp[j], dp[j-w[i]*num]+v[i]*num);
}

混合背包

放入背包的物品可能只有1件(0/1背包),也可能有无限件(完全背包),也可能有可数的几件(多重背包)。

思路:分类讨论,根据物品类型选择对应的背包方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// s=-1:0/1背包, s=0:完全背包, s>0:多重背包
for(int i=1; i<=n; i++)
{
if(s[i]==-1) // 0/1
for(int j=V; j>=w[i]; j--)
dp[j]=max(dp[j], dp[j-w[i]]+v[i]);
else if(s[i]==0) // 完全
for(int j=w[i]; j<=V; j++)
dp[j]=max(dp[j], dp[j-w[i]]+v[i]);
else // 多重(二进制优化)
{
int num=s[i];
for(int k=1; k<=num; k<<=1)
{
num-=k;
for(int j=V; j>=w[i]*k; j--)
dp[j]=max(dp[j], dp[j-w[i]*k]+v[i]*k);
}
if(num>0)
for(int j=V; j>=w[i]*num; j--)
dp[j]=max(dp[j], dp[j-w[i]*num]+v[i]*num);
}
}

二维背包

背包有两个容量限制(如重量和体积),第 $i$ 种物品的重量为 $w_i$,体积为 $vol_i$,价值为 $v_i$。

思路:在0/1背包基础上,状态增加一维,枚举两个维度的容量。

1
2
3
4
5
6
7
8
int dp[N][N]; // dp[i][j]表示重量i、体积j时的最大价值
int w[N], v[N], vol[N];
int W, VOL;

for(int i=1; i<=n; i++)
for(int j=W; j>=w[i]; j--)
for(int k=VOL; k>=vol[i]; k--)
dp[j][k]=max(dp[j][k], dp[j-w[i]][k-vol[i]]+v[i]);

最长上升子序列(LIS)

给定长度为 $n$ 的序列,求严格递增的最长子序列长度。

思路:朴素DP为 $O(n^2)$。优化方法:维护数组 $d[i]$ 表示长度为 $i+1$ 的上升子序列末尾元素的最小值,该数组单调递增,可用二分查找优化到 $O(n\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// O(nlogn) 版本
int LIS(vector<int> &a)
{
vector<int> d;
for(int x: a)
{
int pos=lower_bound(d.begin(), d.end(), x)-d.begin();
if(pos==d.size()) d.push_back(x);
else d[pos]=x;
}
return d.size();
}

// 最长不下降(允许相等):lower_bound改为upper_bound

最长公共子序列(LCS)

给定两个字符串 $a,b$,求它们的最长公共子序列长度。

思路

$dp[i][j]$ 表示 $a$ 的前 $i$ 个字符和 $b$ 的前 $j$ 个字符的LCS长度。

若 $a[i]=b[j]$ 则 $dp[i][j]=dp[i-1][j-1]+1$,否则 $dp[i][j]=\max(dp[i-1][j],dp[i][j-1])$。

1
2
3
4
5
6
7
8
9
10
11
12
int LCS(const string &a, const string &b)
{
int n=a.size(), m=b.size();
vector<vector<int>> dp(n+1, vector<int>(m+1, 0));
for(int i=1; i<=n; i++)
for(int j=1; j<=m; j++)
{
if(a[i-1]==b[j-1]) dp[i][j]=dp[i-1][j-1]+1;
else dp[i][j]=max(dp[i-1][j], dp[i][j-1]);
}
return dp[n][m];
}

状态压缩DP

用二进制表示状态集合,适用于 $n \le 20$ 的小规模问题。

思路:用整数的二进制位表示集合,第 $i$ 位为1表示第 $i$ 个元素在集合中。常用操作:S|(1<<i) 加入元素,S&(~(1<<i)) 删除元素,S&(1<<i) 判断是否在集合中。

旅行商问题(TSP)

有 $n$ 个城市,已知任意两城市间距离,从城市0出发访问所有城市恰好一次再回到0,求最短路径。

思路:$dp[S][i]$ 表示已访问城市集合为 $S$,当前在城市 $i$ 的最短路径长度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
int n, dist[N][N];
int dp[1<<N][N];

int TSP()
{
memset(dp, 0x3f, sizeof(dp));
dp[1][0]=0;
for(int S=0; S<(1<<n); S++)
for(int i=0; i<n; i++)
{
if(!(S&(1<<i))) continue;
for(int j=0; j<n; j++)
if(!(S&(1<<j)))
dp[S|(1<<j)][j]=min(dp[S|(1<<j)][j], dp[S][i]+dist[i][j]);
}
int ans=INT_MAX;
for(int i=0; i<n; i++)
ans=min(ans, dp[(1<<n)-1][i]+dist[i][0]);
return ans;
}

子集枚举

1
2
3
4
5
// 枚举S的所有子集(不含空集)
for(int sub=S; sub; sub=(sub-1)&S)
{
// 处理子集sub
}

状压DP - 棋盘放置

在 $n\times m$ 棋盘上放置方块,相邻格子不能同时放置,求方案数。

思路:$dp[i][S]$ 表示前 $i$ 行已放置,第 $i$ 行状态为 $S$ 的方案数。预处理合法状态(同一行不相邻),转移时检查相邻行不冲突。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int n, m;
ll dp[N][1<<M];
bool ok[1<<M];

for(int S=0; S<(1<<m); S++)
ok[S]=!((S&(S<<1))); // 预处理合法状态

dp[0][0]=1;
for(int i=0; i<n; i++)
for(int S=0; S<(1<<m); S++)
if(ok[S])
for(int nS=0; nS<(1<<m); nS++)
if(ok[nS] && !(S&nS))
dp[i+1][nS]+=dp[i][S];

数位DP

统计区间 $[L,R]$ 内满足某种数位条件的数的个数。

思路:记忆化搜索,按位枚举数字,维护状态和约束条件。limit 表示是否贴上界,lead 表示是否有前导零。答案为 $solve(R)-solve(L-1)$。

例:不含连续49的数的个数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
ll dp[20][2];
bool vis[20][2];
int a[20];

ll dfs(int pos, int pre, bool limit)
{
if(pos==0) return 1;
if(!limit && vis[pos][pre]) return dp[pos][pre];

int up=limit? a[pos]: 9;
ll res=0;
for(int i=0; i<=up; i++)
{
if(pre==1 && i==9) continue; // 上一位是4,这一位是9,跳过
res+=dfs(pos-1, i==4, limit&&i==up);
}
if(!limit) vis[pos][pre]=true, dp[pos][pre]=res;
return res;
}

ll solve(ll x)
{
int len=0;
while(x) a[++len]=x%10, x/=10;
memset(vis, false, sizeof(vis));
return dfs(len, 0, true);
}
// 答案:solve(R) - solve(L-1)

树形DP

树的最大独立集

在树上选择若干个不相邻的节点,使权值和最大。

思路:$dp[u][0]$ 表示不选 $u$ 的最大权值,$dp[u][1]$ 表示选 $u$ 的最大权值。若选 $u$ 则子节点都不能选。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int w[N], dp[N][2];

void dfs(int u, int fa)
{
dp[u][0]=0; dp[u][1]=w[u];
for(int v: tree[u])
{
if(v==fa) continue;
dfs(v, u);
dp[u][0]+=max(dp[v][0], dp[v][1]);
dp[u][1]+=dp[v][0];
}
}
// 答案:max(dp[1][0], dp[1][1])

树上背包

在树上选 $k$ 个节点使权值和最大。

思路:$dp[u][j]$ 表示 $u$ 子树中选 $j$ 个节点的最大权值。合并子树时类似背包DP。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int dp[N][N], sz[N], k;

void dfs(int u, int fa)
{
sz[u]=1; dp[u][1]=w[u];
for(int v: tree[u])
{
if(v==fa) continue;
dfs(v, u);
for(int i=min(k, sz[u]); i>=1; i--)
for(int j=1; j<=min(k-i, sz[v]); j++)
dp[u][i+j]=max(dp[u][i+j], dp[u][i]+dp[v][j]);
sz[u]+=sz[v];
}
}

区间DP

石子合并

有 $n$ 堆石子排成一行,每次可以合并相邻的两堆,代价为两堆石子数量之和,求合并成一堆的最小代价。

思路:$dp[i][j]$ 表示合并区间 $[i,j]$ 的最小代价。枚举分割点 $k$,$dp[i][j]=\min(dp[i][k]+dp[k+1][j]+sum[i,j])$。

1
2
3
4
5
6
7
8
9
10
11
12
13
int a[N], sum[N], dp[N][N];

for(int i=1; i<=n; i++) sum[i]=sum[i-1]+a[i];

for(int len=2; len<=n; len++)
for(int i=1; i+len-1<=n; i++)
{
int j=i+len-1;
dp[i][j]=INT_MAX;
for(int k=i; k<j; k++)
dp[i][j]=min(dp[i][j], dp[i][k]+dp[k+1][j]+sum[j]-sum[i-1]);
}
// 答案:dp[1][n]

环形石子合并

石子排成一圈,求最小代价。

思路:破环成链,复制一份接在后面,枚举起点,取长度为 $n$ 的区间的最小值。

1
2
3
4
5
6
for(int i=1; i<=n; i++) a[i+n]=a[i];
for(int i=1; i<=2*n; i++) sum[i]=sum[i-1]+a[i];
// 同上DP,枚举len从2到n
int ans=INT_MAX;
for(int i=1; i<=n; i++)
ans=min(ans, dp[i][i+n-1]);

其他经典DP

最大子段和(Kadane算法)

给定数组,求最大连续子段和。

思路:$dp[i]$ 表示以 $i$ 结尾的最大子段和,$dp[i]=\max(a[i], dp[i-1]+a[i])$。

1
2
3
4
5
6
7
8
9
10
ll maxSubArray(vector<int> &a)
{
ll ans=LLONG_MIN, sum=0;
for(int x: a)
{
sum=max((ll)x, sum+x);
ans=max(ans, sum);
}
return ans;
}

最大子矩阵和

给定二维矩阵,求最大子矩阵和。

思路:枚举上下边界,将列压缩成一维数组,转化为最大子段和问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
ll maxSubMatrix(vector<vector<int>> &mat)
{
int n=mat.size(), m=mat[0].size();
ll ans=LLONG_MIN;
for(int up=0; up<n; up++)
{
vector<ll> sum(m, 0);
for(int down=up; down<n; down++)
{
for(int j=0; j<m; j++) sum[j]+=mat[down][j];
ll cur=0;
for(int j=0; j<m; j++)
{
cur=max(sum[j], cur+sum[j]);
ans=max(ans, cur);
}
}
}
return ans;
}

编辑距离

将字符串 $a$ 变成 $b$,可以插入、删除、替换字符,求最少操作次数。

思路:$dp[i][j]$ 表示 $a$ 的前 $i$ 个字符变成 $b$ 的前 $j$ 个字符的最少操作数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int editDistance(const string &a, const string &b)
{
int n=a.size(), m=b.size();
vector<vector<int>> dp(n+1, vector<int>(m+1));
for(int i=0; i<=n; i++) dp[i][0]=i;
for(int j=0; j<=m; j++) dp[0][j]=j;
for(int i=1; i<=n; i++)
for(int j=1; j<=m; j++)
{
if(a[i-1]==b[j-1]) dp[i][j]=dp[i-1][j-1];
else dp[i][j]=min({dp[i-1][j], dp[i][j-1], dp[i-1][j-1]})+1;
}
return dp[n][m];
}

优化技巧

有的时候实在想不到更优的方法,可以尝试将暴力方法优化。

离散化

将大范围的数据映射到小范围,常用于值域很大但数据量较小的情况。

将数值映射到 $[1,n]$ 的连续整数,保持相对大小关系不变。

1
2
3
4
5
6
7
8
9
10
11
12
13
// 基本离散化
vector<int> a; // 原始数据
vector<int> b=a; // 备份用于映射

sort(b.begin(), b.end());
b.erase(unique(b.begin(), b.end()), b.end()); // 去重

// 离散化:将a[i]映射到[1,len]
for(int i=0; i<a.size(); i++)
a[i]=lower_bound(b.begin(), b.end(), a[i])-b.begin()+1;

// 还原:从离散化值还原为原值
int original_val=b[a[i]-1];

区间离散化

处理区间问题时,需要离散化区间端点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
vector<pii> segs; // 区间 {l, r}
vector<int> nums;

for(auto [l, r]: segs)
{
nums.push_back(l);
nums.push_back(r);
}

sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());

// 查询x离散化后的值
auto getID = [&](int x) {
return lower_bound(nums.begin(), nums.end(), x)-nums.begin()+1;
};

莫队算法

离线处理区间查询问题,时间复杂度 $O(n\sqrt{n})$。适用于可增量维护的区间信息。

将询问按照左端点所在块排序,相同块内按右端点排序。通过移动左右指针,增量维护答案。

普通莫队

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
int n, m, blockSize;
int a[N], cnt[N], ans[N];
int curAns=0;

struct Query
{
int l, r, id;
bool operator<(const Query &q) const
{
if(l/blockSize!=q.l/blockSize) return l<q.l;
return (l/blockSize)&1? r<q.r: r>q.r; // 奇偶优化
}
}q[N];

void add(int pos)
{
// 加入位置pos的元素
if(++cnt[a[pos]]==1) curAns++;
}

void del(int pos)
{
// 删除位置pos的元素
if(--cnt[a[pos]]==0) curAns--;
}

void solve()
{
blockSize=sqrt(n);
sort(q, q+m);

int l=1, r=0;
for(int i=0; i<m; i++)
{
while(l>q[i].l) add(--l);
while(r<q[i].r) add(++r);
while(l<q[i].l) del(l++);
while(r>q[i].r) del(r--);
ans[q[i].id]=curAns;
}
}

带修改莫队

支持单点修改的莫队,时间复杂度 $O(n^{5/3})$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
int n, m, qcnt, mcnt, blockSize;
int a[N], cnt[N], ans[N];
int curAns=0;

struct Query
{
int l, r, t, id; // t为时间戳
bool operator<(const Query &q) const
{
if(l/blockSize!=q.l/blockSize) return l<q.l;
if(r/blockSize!=q.r/blockSize) return r<q.r;
return t<q.t;
}
}queries[N];

struct Modify
{
int pos, val; // 将pos位置修改为val
}modifies[N];

void add(int pos) { if(++cnt[a[pos]]==1) curAns++; }
void del(int pos) { if(--cnt[a[pos]]==0) curAns--; }

void applyModify(int i, int l, int r)
{
if(l<=modifies[i].pos && modifies[i].pos<=r)
{
del(modifies[i].pos);
swap(a[modifies[i].pos], modifies[i].val);
add(modifies[i].pos);
}
else swap(a[modifies[i].pos], modifies[i].val);
}

void solve()
{
blockSize=pow(n, 2.0/3);
sort(queries, queries+qcnt);

int l=1, r=0, t=0;
for(int i=0; i<qcnt; i++)
{
while(l>queries[i].l) add(--l);
while(r<queries[i].r) add(++r);
while(l<queries[i].l) del(l++);
while(r>queries[i].r) del(r--);
while(t<queries[i].t) applyModify(++t, l, r);
while(t>queries[i].t) applyModify(t--, l, r);
ans[queries[i].id]=curAns;
}
}

CDQ分治

处理三维偏序、动态规划优化等问题,时间复杂度 $O(n\log^2 n)$。

核心思想:分治时只统计左半部分对右半部分的贡献。

三维偏序问题

给定 $n$ 个三元组 $(a_i,b_i,c_i)$,求每个点的支配数(有多少个点三维都 $\le$ 它)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
const int N=1e5+5;

struct Point
{
int a, b, c, ans, cnt;
bool operator==(const Point &p) const
{
return a==p.a && b==p.b && c==p.c;
}
}p[N], tmp[N];

int n, bit[N], maxc;

void add(int x, int v)
{
for(; x<=maxc; x+=x&-x) bit[x]+=v;
}

int query(int x)
{
int res=0;
for(; x; x-=x&-x) res+=bit[x];
return res;
}

void cdq(int l, int r)
{
if(l==r) return;
int mid=(l+r)>>1;
cdq(l, mid); cdq(mid+1, r);

// 归并排序,按b排序
int i=l, j=mid+1, k=l;
while(i<=mid && j<=r)
{
if(p[i].b<=p[j].b)
{
add(p[i].c, p[i].cnt);
tmp[k++]=p[i++];
}
else
{
p[j].ans+=query(p[j].c);
tmp[k++]=p[j++];
}
}
while(j<=r) p[j].ans+=query(p[j].c), tmp[k++]=p[j++];
for(int t=l; t<i; t++) add(p[t].c, -p[t].cnt);
while(i<=mid) tmp[k++]=p[i++];
for(int t=l; t<=r; t++) p[t]=tmp[t];
}

/*
使用:
1. 先对a排序
2. 相同的点合并,cnt记录个数
3. cdq(1, n)
4. p[i].ans为答案
*/

点分治

处理树上路径问题(如路径长度、路径点数等),时间复杂度 $O(n\log n)$。

核心思想:找树的重心作为根,统计经过根的路径,然后递归处理子树。

树上路径统计

求树上距离 $\le k$ 的点对数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
const int N=1e4+5;
vector<pii> g[N]; // {邻接点, 边权}
int n, k;
bool vis[N];
int sz[N], maxsz[N], sum;
int root;

// 找重心
void getRoot(int u, int fa)
{
sz[u]=1; maxsz[u]=0;
for(auto [v, w]: g[u])
{
if(v==fa || vis[v]) continue;
getRoot(v, u);
sz[u]+=sz[v];
maxsz[u]=max(maxsz[u], sz[v]);
}
maxsz[u]=max(maxsz[u], sum-sz[u]);
if(maxsz[u]<maxsz[root]) root=u;
}

vector<int> dist;

void getDist(int u, int fa, int d)
{
dist.push_back(d);
for(auto [v, w]: g[u])
{
if(v==fa || vis[v]) continue;
getDist(v, u, d+w);
}
}

int calc(int u, int init)
{
dist.clear();
getDist(u, 0, init);
sort(dist.begin(), dist.end());

int res=0;
int l=0, r=dist.size()-1;
while(l<r)
{
if(dist[l]+dist[r]<=k) res+=r-l, l++;
else r--;
}
return res;
}

int ans=0;

void solve(int u)
{
vis[u]=true;
ans+=calc(u, 0);

for(auto [v, w]: g[u])
{
if(vis[v]) continue;
ans-=calc(v, w); // 减去同一子树内的路径
sum=sz[v];
root=0; maxsz[0]=INT_MAX;
getRoot(v, 0);
solve(root);
}
}

/*
使用:
sum=n;
root=0; maxsz[0]=INT_MAX;
getRoot(1, 0);
solve(root);
cout<<ans<<endl;
*/

点分治模板(通用)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
int n, k, ans;
vector<pii> g[N];
bool vis[N];
int sz[N], maxsz[N], root;

void getRoot(int u, int fa, int total)
{
sz[u]=1; maxsz[u]=0;
for(auto [v, w]: g[u])
{
if(v==fa || vis[v]) continue;
getRoot(v, u, total);
sz[u]+=sz[v];
maxsz[u]=max(maxsz[u], sz[v]);
}
maxsz[u]=max(maxsz[u], total-sz[u]);
if(maxsz[u]<maxsz[root]) root=u;
}

vector<int> dist;
void getDist(int u, int fa, int d)
{
dist.push_back(d);
for(auto [v, w]: g[u])
if(v!=fa && !vis[v])
getDist(v, u, d+w);
}

// 统计以u为根,经过u的路径
int calc(int u)
{
// 具体问题具体实现
return 0;
}

void divide(int u)
{
vis[u]=true;
ans+=calc(u);

for(auto [v, w]: g[u])
{
if(vis[v]) continue;
root=0; maxsz[0]=INT_MAX;
getRoot(v, 0, sz[v]);
divide(root);
}
}

void pointDivide()
{
root=0; maxsz[0]=INT_MAX;
getRoot(1, 0, n);
divide(root);
}

STL

容器

tuple

元组,可以存储不同类型的元素。

1
2
tuple<string, int, int> Student = {"Bob", 18, 213213};
cout<<get<0>(Student)<<endl; // 输出 "Bob"

取对象中的第index个元素get<index>(obj)

注意:这里的index只能手动输入,使用for循环这样的自动输入是不可以的

array

固定大小的数组,存储在栈上。

1
2
array<int, 3> x; // 建立一个包含三个元素的数组x
cout<<x[0]; // 使用[]随机访问
  • size() / empty()
  • [] 随机访问
  • front() / back() 获取首尾元素
  • fill(val) 填充所有元素为val

vector

动态数组,最常用的容器。

声明方式

1
2
3
4
5
6
7
// 不推荐:使用[]声明多维变长数组
vector<int> ver[n+1]; // ✖

// 推荐:使用嵌套方式
vector<vector<int>> ver(n+1,0); // ✓
vector dis(n+1, vector<int>(m+1)); // 二维
vector dis(m+1, vector(n+1, vector<int>(k+1))); // 三维

常用函数

  • size() / empty() / clear()
  • resize(n) 重设容器大小,但不改变已有元素的值
  • assign(n, val) 重设容器大小为n,且替换容器内容为val
  • push_back(x) / pop_back() 尾部插入/删除
  • front() / back() 获取首尾元素
  • begin() / end() 迭代器
  • [] 随机访问
  • insert(it, x) 在迭代器it处插入x
  • erase(it) / erase(first, last) 删除元素

stack

栈,栈顶入,栈顶出。先进后出

  • 没有clear函数
  • size() / empty()
  • push(x) 向栈顶插入x
  • top() 获取栈顶元素
  • pop() 弹出栈顶元素

queue

队列,队尾进,队头出。先进先出

  • 没有clear函数,但可以用重新构造替代:
    1
    2
    queue<int> q;
    q = queue<int>();
  • size() / empty()
  • push(x) 向队尾插入x
  • front() / back() 获取队头、队尾元素
  • pop() 弹出队头元素

deque

双向队列,两端都可以插入和删除。

  • size() / empty() / clear()
  • push_front(x) / push_back(x) 首尾插入
  • pop_front() / pop_back() 首尾删除
  • front() / back() 获取首尾元素
  • begin() / end() 迭代器
  • [] 随机访问

priority_queue

优先队列,默认升序(大根堆)

1
2
3
4
5
// 默认大根堆,堆顶最大值
priority_queue<int> pq;

// 小根堆
priority_queue<int, vector<int>, greater<int>> pq;
  • 没有clear函数
  • push(x) 插入元素
  • top() 获取堆顶元素
  • pop() 弹出堆顶元素
  • size() / empty()

自定义排序(重载运算符)

注意:符号相反!!!

1
2
3
4
5
6
7
8
struct Node {
int x;
string s;
friend bool operator < (const Node &a, const Node &b) {
if(a.x != b.x) return a.x > b.x; // 想要小的在前,这里写>
return a.s > b.s;
}
};

string

字符串容器。

  • size() / empty() / clear()
  • push_back(c) / pop_back() 尾部插入/删除字符
  • [] 随机访问
  • substr(start, len) 从start开始取长度为len的子串
    • len省略时默认取到结尾
    • 超过字符串长度时也默认取到结尾
      1
      cout<<S.substr(1, 12);
  • find(x) / rfind(x) 顺序、逆序查找x,返回下标
    • 没找到时返回一个极大值
    • 建议与 size() 比较,而不要和 -1 比较,后者可能出错
      1
      if(s.find("abc") < s.size()) // ✓
  • 没有count函数
  • append(str) / += 追加字符串
  • compare(str) 字符串比较
  • replace(pos, len, str) 替换子串

set / multiset

有序集合,默认升序,时间复杂度 $O(\log n)$

  • set 去重
  • multiset 不去重

常用函数

  • size() / empty() / clear()
  • insert(x) 插入元素
  • erase(x) 两种删除方式
    • 当x为某一元素时,删除所有这个数,复杂度 $O(k+\log n)$,k为删除个数
    • 当x为迭代器时,删除这个迭代器,复杂度 $O(1)$
  • find(x) 查找元素,返回迭代器,没找到返回end()
  • count(x) 统计元素个数
  • lower_bound(x) 返回 $\ge x$ 的第一个元素的迭代器
  • upper_bound(x) 返回 $> x$ 的第一个元素的迭代器
  • begin() / end() 迭代器
  • rbegin() / rend() 反向迭代器

特殊函数 next 和 prev

1
2
3
4
set<int> s = {1, 3, 5, 7};
auto it = s.find(3);
auto next_it = next(it); // 指向5
auto prev_it = prev(it); // 指向1

map / multimap

键值对容器,默认按键升序,时间复杂度 $O(\log n)$,$n$ 为元素数量。

  • map 键去重
  • multimap 键不去重

常用函数

  • size() / empty() / clear()
  • insert({key, value}) 插入键值对
  • erase(x) 两种删除方式
    • 当x为某一元素时,删除所有以这个元素为下标的二元组,复杂度 $O(k+\log n)$
    • 当x为迭代器时,删除这个迭代器,复杂度 $O(1)$
  • find(key) 查找键,返回迭代器
  • count(key) 统计键的个数
  • [] 访问/修改值,若键不存在会自动创建
  • begin() / end() 迭代器

慎用随机访问! 当不确定某次查询是否存在于容器中时,不要直接使用下标查询,而是先使用 count() 或者 find() 方法检查key值,防止不必要的零值二元组被构造。

1
2
3
4
5
6
// 不好的做法✖
if(mp[key]) // 若key不存在,会创建mp[key]=0

// 推荐做法✓
if(mp.count(key))
if(mp.find(key) != mp.end())

慎用自带的 pair、tuple 作为key值类型!使用自定义结构体!

unordered_set / unordered_map

无序集合/映射,基于哈希表,平均时间复杂度 $O(1)$,最坏 $O(n)$。

常用函数:与 set / map 类似,但没有 lower_bound / upper_bound

  • size() / empty() / clear()
  • insert(x) / erase(x) / find(x) / count(x)
  • [] (仅unordered_map)

bitset

位集,将数据转换为二进制,从高位到低位排序,以 $0$ 为最低位。当位数相同时支持全部的位运算。

1
2
3
bitset<10> bs; // 10位bitset
bitset<10> bs(5); // 二进制 0000000101
bitset<10> bs("1010"); // 二进制 0000001010

常用函数

  • 没有clear函数
  • size() / empty()
  • count() 统计1的个数
  • any() 是否有1
  • none() 是否全为0
  • all() 是否全为1
  • set() 全部置1
  • set(pos) 将pos位置1
  • set(pos, val) 将pos位设为val
  • reset() 全部置0
  • reset(pos) 将pos位置0
  • flip() 全部翻转
  • flip(pos) 翻转pos位
  • test(pos) 返回pos位的值
  • [] 访问某一位
  • to_string() 转换为字符串
  • to_ulong() / to_ullong() 转换为整数

位运算

1
2
3
4
5
6
7
bitset<4> bs1("1010"), bs2("1100");
cout << (bs1 & bs2); // 1000 与
cout << (bs1 | bs2); // 1110 或
cout << (bs1 ^ bs2); // 0110 异或
cout << (~bs1); // 0101 取反
cout << (bs1 << 1); // 0100 左移
cout << (bs1 >> 1); // 0101 右移

常用算法函数

sort

排序,时间复杂度 $O(n\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
sort(a, a+n); // 数组排序
sort(v.begin(), v.end()); // vector排序

// 降序
sort(v.begin(), v.end(), greater<int>());

// 自定义比较函数
bool cmp(int a, int b) { return a > b; }
sort(v.begin(), v.end(), cmp);

// lambda表达式
sort(v.begin(), v.end(), [](int a, int b) { return a > b; });

reverse

反转容器。

1
2
reverse(a, a+n);
reverse(v.begin(), v.end());

unique

去重(需要先排序),返回去重后的尾迭代器。

1
2
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());

lower_bound / upper_bound

二分查找(需要有序)。

1
2
3
4
5
6
7
8
// lower_bound: 返回 >= x 的第一个位置(迭代器)
auto it = lower_bound(v.begin(), v.end(), x);

// upper_bound: 返回 > x 的第一个位置(迭代器)
auto it = upper_bound(v.begin(), v.end(), x);

// 数组
int pos = lower_bound(a, a+n, x) - a;

min / max / swap

1
2
3
4
int mn = min(a, b);
int mx = max(a, b);
int mn = min({a, b, c});
swap(a, b);

next_permutation / prev_permutation

全排列

next_permutation是生成下一个排列,即字典序中下一个比当前排列大的排列(升序->降序)
prev_permutation是生成上一个排列,即字典序中上一个比当前排列小的排列(降序->升序)

1
2
3
4
vector<int> v = {1, 2, 3};
do {
// 处理当前排列
} while(next_permutation(v.begin(), v.end()));

accumulate

求和,需要 #include <numeric>

1
2
int sum = accumulate(v.begin(), v.end(), 0);
ll sum = accumulate(v.begin(), v.end(), 0LL); // 注意初始值类型

fill / memset

1
2
3
4
fill(a, a+n, 0); // 任意值
fill(v.begin(), v.end(), 0);

memset(a, 0, sizeof(a)); // 只能用于0和-1

to_string / stoi / stoll

字符串和数字之间的转换。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 数字转字符串
string s = to_string(123);
string s2 = to_string(3.14);

// 字符串转整数
int a = stoi("123");
long long b = stoll("123456789012345");

// 支持不同进制
int hex = stoi("1A", nullptr, 16); // 26
int bin = stoi("1010", nullptr, 2); // 10

// 注意:stoi会忽略前导空格,遇到非数字字符停止
int c = stoi(" 123abc"); // 123

注意:

  • stoi是转换为i32stoll是转换为i64
  • 超出范围会抛出 out_of_range 异常
  • 格式错误会抛出 invalid_argument 异常