传送门:https://pintia.cn/problem-sets/994805046380707840/exam/problems/1518582895035215872?type=7&page=1
思路
观察发现,逆序对可以分成两类:
- 节点 u u u 和 v v v 有明确的父子关系(不一定是直属的直连边,可能是儿子往上跳很多条边才到达父亲),假设 u u u 是 v v v 的父亲,那么不管怎么遍历, u u u 一定会比 v v v 更早遍历到,因为 v v v 在 u u u 的子树里;那么这种情况下,如果 u > v u > v u>v 的话,就会产生一个逆序对
- 节点 u u u 和 v v v 没有明确的父子关系,那么假设它们的 L C A LCA LCA 为 L L L, u u u 和 v v v 的出现次序取决于 L L L 的遍历顺序,因此这一部分的逆序对数量显然和 L L L 的直属儿子数 的 全排列 有关
对于第一部分的答案,我们很容易就可以统计,进入一个点 u u u 时,我们在 u u u 这个位置打上标记,那么之后所有在 u u u 的子树中的点 v v v,都会被这个标记影响,对于当前点 v v v,要统计其有多少个父亲的点权大于它(产生逆序对),只需要使用树状数组查询 [ v + 1 , n ] [v + 1, n] [v+1,n] 的和即可;离开 u u u 时,我们删除这个标记即可
假设我们求出来第一部分产生的逆序对为
c
n
t
cnt
cnt 个,不管
d
f
s
dfs
dfs 序如果变化,这第一部分产生的逆序对是恒定不变的
一共有
∏
s
o
n
i
!
\prod son_i !
∏soni! 个
d
f
s
dfs
dfs 序(
s
o
n
i
son_i
soni 表示节点
i
i
i 的直属儿子数量),
s
o
n
i
>
0
son_i > 0
soni>0
所以第一部分的答案是:
c
n
t
×
∏
s
o
n
i
!
cnt \times \prod son_i !
cnt×∏soni!
对于第二部分的答案,我们对于当前节点
u
u
u,假设
s
z
[
v
]
sz[v]
sz[v] 为点
v
v
v 的子树中的节点数量,那么 子树
u
u
u 中,以
u
u
u 为
L
C
A
LCA
LCA 的点对两两配对的方案数有:
s
z
[
v
1
]
×
s
z
[
v
2
]
+
s
z
[
v
1
]
×
s
z
[
v
3
]
+
.
.
.
+
s
z
[
v
k
−
1
]
×
s
z
[
v
k
]
(假设
u
有
k
个儿子)
sz[v_1] \times sz[v_2] + sz[v_1] \times sz[v_3] +... + sz[v_{k-1}] \times sz[v_k](假设 u 有 k 个儿子)
sz[v1]×sz[v2]+sz[v1]×sz[v3]+...+sz[vk−1]×sz[vk](假设u有k个儿子),就是将
u
u
u 的所有儿子的子树大小两两相乘,从而算出不在同一颗子树的点对数量;
对于某个点对
(
x
,
y
)
(x,y)
(x,y),
x
,
y
x,y
x,y 满足某种大小关系,假设
x
<
y
x < y
x<y,那么当
x
x
x 位于的子树比
y
y
y 位于的子树后遍历到时,这个点对会产生一个逆序对,
而对于
u
u
u 的直属儿子的全排列,从期望或者概率的角度出发,一定是恰好有一半的情况使得
x
x
x 在
y
y
y 前面,而剩下一半的情况
x
x
x 在
y
y
y 的后面,所以我们只需要将
k
!
÷
2
k! \div 2
k!÷2 就是产生逆序对的情况数(
k
k
k 是
u
u
u 的直属儿子数量)
但是,除了当前
u
u
u 的直属儿子遍历顺序会变,其他节点的遍历顺序也会变,全部累乘起来的话,其实和第一部分的
∏
s
o
n
i
\prod son_i
∏soni 是一样的
因此对于当前节点
u
u
u,其不同子树内贡献的逆序对数量为:
∏
s
o
n
i
2
×
s
u
m
\dfrac {\prod son_i}{2} \times sum
2∏soni×sum,
s
u
m
sum
sum 代表
u
u
u 中不同子树内的点对数量,对于不同的
u
u
u,可以将
∏
s
o
n
i
2
\dfrac {\prod son_i}{2}
2∏soni 提公因子出来,只累加
s
u
m
sum
sum 到最后算即可
关于
u
u
u 的来自不同子树内的点对数量,可以简单地使用类似前缀和来解决,详情看代码
d
f
s
0
dfs0
dfs0 的部分
#include<bits/stdc++.h>
#define fore(i,l,r) for(int i=(int)(l);i<(int)(r);++i)
#define fi first
#define se second
#define endl '\n'
#define ull unsigned long long
#define ALL(v) v.begin(), v.end()
#define Debug(x, ed) std::cerr << #x << " = " << x << ed;
#define lowbit(x) ((x) & -(x))
const int INF=0x3f3f3f3f;
const long long INFLL=1e18;
typedef long long ll;
template<class T>
constexpr T power(T a, ll b){
T res = 1;
while(b){
if(b&1) res = res * a;
a = a * a;
b >>= 1;
}
return res;
}
constexpr ll mul(ll a,ll b,ll mod){ //快速乘,避免两个long long相乘取模溢出
ll res = a * b - ll(1.L * a * b / mod) * mod;
res %= mod;
if(res < 0) res += mod; //误差
return res;
}
template<ll P>
struct MLL{
ll x;
constexpr MLL() = default;
constexpr MLL(ll x) : x(norm(x % getMod())) {}
static ll Mod;
constexpr static ll getMod(){
if(P > 0) return P;
return Mod;
}
constexpr static void setMod(int _Mod){
Mod = _Mod;
}
constexpr ll norm(ll x) const{
if(x < 0){
x += getMod();
}
if(x >= getMod()){
x -= getMod();
}
return x;
}
constexpr ll val() const{
return x;
}
explicit constexpr operator ll() const{
return x; //将结构体显示转换为ll类型: ll res = static_cast<ll>(OBJ)
}
constexpr MLL operator -() const{ //负号,等价于加上Mod
MLL res;
res.x = norm(getMod() - x);
return res;
}
constexpr MLL inv() const{
assert(x != 0);
return power(*this, getMod() - 2); //用费马小定理求逆
}
constexpr MLL& operator *= (MLL rhs) & { //& 表示“this”指针不能指向一个临时对象或const对象
x = mul(x, rhs.x, getMod()); //该函数只能被一个左值调用
return *this;
}
constexpr MLL& operator += (MLL rhs) & {
x = norm(x + rhs.x);
return *this;
}
constexpr MLL& operator -= (MLL rhs) & {
x = norm(x - rhs.x);
return *this;
}
constexpr MLL& operator /= (MLL rhs) & {
return *this *= rhs.inv();
}
friend constexpr MLL operator * (MLL lhs, MLL rhs){
MLL res = lhs;
res *= rhs;
return res;
}
friend constexpr MLL operator + (MLL lhs, MLL rhs){
MLL res = lhs;
res += rhs;
return res;
}
friend constexpr MLL operator - (MLL lhs, MLL rhs){
MLL res = lhs;
res -= rhs;
return res;
}
friend constexpr MLL operator / (MLL lhs, MLL rhs){
MLL res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream& operator >> (std::istream& is, MLL& a){
ll v;
is >> v;
a = MLL(v);
return is;
}
friend constexpr std::ostream& operator << (std::ostream& os, MLL& a){
return os << a.val();
}
friend constexpr bool operator == (MLL lhs, MLL rhs){
return lhs.val() == rhs.val();
}
friend constexpr bool operator != (MLL lhs, MLL rhs){
return lhs.val() != rhs.val();
}
};
const ll mod = 1e9 + 7;
using Z = MLL<mod>;
struct Comb {
int n;
std::vector<Z> _fac;
std::vector<Z> _invfac;
std::vector<Z> _inv;
Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
Comb(int n) : Comb() {
init(n);
}
void init(int m) {
m = std::min(1ll * m, Z::getMod() - 1);
if (m <= n) return; //已经处理完了需要的长度
_fac.resize(m + 1);
_invfac.resize(m + 1);
_inv.resize(m + 1);
for (int i = n + 1; i <= m; i++) {
_fac[i] = _fac[i - 1] * i;
}
_invfac[m] = _fac[m].inv();
for (int i = m; i > n; i--) { //线性递推逆元和阶乘逆元
_invfac[i - 1] = _invfac[i] * i;
_inv[i] = _invfac[i] * _fac[i - 1];
}
n = m; //新的长度
}
Z fac(int m) {
if (m > n) init(2 * m);
return _fac[m];
}
Z invfac(int m) {
if (m > n) init(2 * m);
return _invfac[m];
}
Z inv(int m) {
if (m > n) init(2 * m);
return _inv[m];
}
Z binom(int n, int m) { //二项式系数
if (n < m || m < 0) return 0;
return fac(n) * invfac(m) * invfac(n - m);
}
} comb;
const int N = 300050;
int fen[N]; //树状数组
std::vector<int> g[N];
Z k = 1; //dfs序数量
Z cnt = 0;
Z sum = 0;
int n, root;
int sz[N]; //子树大小
int query(int p){
int res = 0;
while(p > 0){
res += fen[p];
p -= lowbit(p);
}
return res;
}
void update(int p, int d){
while(p <= n){
fen[p] += d;
p += lowbit(p);
}
}
void dfs0(int u, int fa){
cnt += query(n) - query(u);
update(u, 1); //进入这个点,加上影响
int num = (g[u].size() - 1); //儿子数,减去父亲那条边
if(u == root) ++num; //根节点没有父亲
if(num > 0) k *= comb.fac(num); //排除叶子节点
sz[u] = 1;
Z s = 0; //子树大小前缀和
for(auto v : g[u])
if(v ^ fa){
dfs0(v, u);
sz[u] += sz[v];
sum += sz[v] * s;
s += sz[v];
}
update(u, -1); //离开这个点,删去影响
}
int main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
std::cin >> n >> root;
fore(i, 1, n){
int u, v;
std::cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs0(root, 0);
Z ans = k * cnt + k * sum / 2;
std::cout << ans;
return 0;
}