F. Sasha and the Wedding Binary Search Tree
题意
给定一颗二叉搜索树,规定树上的所有点的点权都在范围 [ 1 , C ] [1, C] [1,C] 内,树上的某些节点点权已知,某些节点点权未知,求出合法的二叉搜索树的数量
思路
由于是二叉搜索树,所以左子树的点权小于等于右子树的点权,我们先进行一次先序遍历,得到一个 o r d e r order order 遍历顺序的数组,其中某些段是未知的点权,我们可以通过其左端点和右端点的点权来约束这些未知点的点权
假如这个未知段左边紧邻的第一个已知点权为 L L L,右边紧邻的第一个已知点权为 R R R,那么显然我们这个未知段的点权是 [ L , R ] [L,R] [L,R] 的,假设未知段长度为 l e n len len,问题等价于选出 l e n len len 个数,按照非递减的顺序摆放的方案数,允许重复数字出现。
我们将问题再转化一下,由于是非递减,所以一定有: v a l i ≤ v a l i + 1 val_i \leq val_{i + 1} vali≤vali+1,那么这一段的差分数组一定是大于等于 0 0 0 的一段,并且总和为 R − L R - L R−L,要加上最后一个 R R R 的位置,长度变为 l e n + 1 len + 1 len+1,例如: L = 2 , R = 5 , l e n = 3 L = 2, R= 5, len = 3 L=2,R=5,len=3,我们构造这样一个方案: 2 , 3 , 3 , 4 , 5 2, 3 , 3, 4, 5 2,3,3,4,5,差分数组为: 2 , 1 , 0 , 1 , 1 2, 1, 0, 1, 1 2,1,0,1,1,忽略第一个位置的话,后面所有元素的和就是 R − L = 3 R - L = 3 R−L=3,那么问题成功转化为了:将 R − L R - L R−L 个相同小球放入 l e n + 1 len + 1 len+1 个不同盒子中(允许空盒子)的方案数,即为: C R − L + l e n l e n C_{R - L + len}^{len} CR−L+lenlen
由于这里 R − L R - L R−L 很大,所以不能直接预处理阶乘,注意到 ∑ l e n ≤ n \sum len \leq n ∑len≤n,我们只需要利用公式: C R − L + l e n l e n = ( R − L + l e n ) ⋅ ( R − L + l e n − 1 ) ⋅ . . . ⋅ ( R − L + 1 ) l e n ! C_{R - L + len}^{len} = \dfrac{(R - L + len) \cdot (R - L + len - 1) \cdot ... \cdot (R - L + 1)}{len!} CR−L+lenlen=len!(R−L+len)⋅(R−L+len−1)⋅...⋅(R−L+1) 来计算即可
#include<bits/stdc++.h>
#define fore(i,l,r) for(int i=(int)(l);i<(int)(r);++i)
#define fi first
#define se second
#define endl '\n'
#define ull unsigned long long
#define ALL(v) v.begin(), v.end()
#define Debug(x, ed) std::cerr << #x << " = " << x << ed;
const int INF=0x3f3f3f3f;
const long long INFLL=1e18;
typedef long long ll;
template<class T>
constexpr T power(T a, ll b){
T res = 1;
while(b){
if(b&1) res = res * a;
a = a * a;
b >>= 1;
}
return res;
}
constexpr ll mul(ll a,ll b,ll mod){ //快速乘,避免两个long long相乘取模溢出
ll res = a * b - ll(1.L * a * b / mod) * mod;
res %= mod;
if(res < 0) res += mod; //误差
return res;
}
template<ll P>
struct MLL{
ll x;
constexpr MLL() = default;
constexpr MLL(ll x) : x(norm(x % getMod())) {}
static ll Mod;
constexpr static ll getMod(){
if(P > 0) return P;
return Mod;
}
constexpr static void setMod(int _Mod){
Mod = _Mod;
}
constexpr ll norm(ll x) const{
if(x < 0){
x += getMod();
}
if(x >= getMod()){
x -= getMod();
}
return x;
}
constexpr ll val() const{
return x;
}
explicit constexpr operator ll() const{
return x; //将结构体显示转换为ll类型: ll res = static_cast<ll>(OBJ)
}
constexpr MLL operator -() const{ //负号,等价于加上Mod
MLL res;
res.x = norm(getMod() - x);
return res;
}
constexpr MLL inv() const{
assert(x != 0);
return power(*this, getMod() - 2); //用费马小定理求逆
}
constexpr MLL& operator *= (MLL rhs) & { //& 表示“this”指针不能指向一个临时对象或const对象
x = mul(x, rhs.x, getMod()); //该函数只能被一个左值调用
return *this;
}
constexpr MLL& operator += (MLL rhs) & {
x = norm(x + rhs.x);
return *this;
}
constexpr MLL& operator -= (MLL rhs) & {
x = norm(x - rhs.x);
return *this;
}
constexpr MLL& operator /= (MLL rhs) & {
return *this *= rhs.inv();
}
friend constexpr MLL operator * (MLL lhs, MLL rhs){
MLL res = lhs;
res *= rhs;
return res;
}
friend constexpr MLL operator + (MLL lhs, MLL rhs){
MLL res = lhs;
res += rhs;
return res;
}
friend constexpr MLL operator - (MLL lhs, MLL rhs){
MLL res = lhs;
res -= rhs;
return res;
}
friend constexpr MLL operator / (MLL lhs, MLL rhs){
MLL res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream& operator >> (std::istream& is, MLL& a){
ll v;
is >> v;
a = MLL(v);
return is;
}
friend constexpr std::ostream& operator << (std::ostream& os, MLL& a){
return os << a.val();
}
friend constexpr bool operator == (MLL lhs, MLL rhs){
return lhs.val() == rhs.val();
}
friend constexpr bool operator != (MLL lhs, MLL rhs){
return lhs.val() != rhs.val();
}
};
const ll mod = 998244353;
using Z = MLL<mod>;
struct Comb {
int n;
std::vector<Z> _fac;
std::vector<Z> _invfac;
std::vector<Z> _inv;
Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
Comb(int n) : Comb() {
init(n);
}
void init(int m) {
m = std::min(1ll * m, Z::getMod() - 1);
if (m <= n) return; //已经处理完了需要的长度
_fac.resize(m + 1);
_invfac.resize(m + 1);
_inv.resize(m + 1);
for (int i = n + 1; i <= m; i++) {
_fac[i] = _fac[i - 1] * i;
}
_invfac[m] = _fac[m].inv();
for (int i = m; i > n; i--) { //线性递推逆元和阶乘逆元
_invfac[i - 1] = _invfac[i] * i;
_inv[i] = _invfac[i] * _fac[i - 1];
}
n = m; //新的长度
}
Z fac(int m) {
if (m > n) init(2 * m);
return _fac[m];
}
Z invfac(int m) {
if (m > n) init(2 * m);
return _invfac[m];
}
Z inv(int m) {
if (m > n) init(2 * m);
return _inv[m];
}
Z binom(int n, int m) { //二项式系数
if (n < m || m < 0) return 0;
return fac(n) * invfac(m) * invfac(n - m);
}
} comb;
const int N = 500050;
int L[N], R[N];
ll val[N];
std::vector<ll> order;
void dfs(int u){
if(~L[u]) dfs(L[u]);
order.push_back(val[u]);
if(~R[u]) dfs(R[u]);
}
int main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int t;
std::cin >> t;
while(t--){
int n;
ll C;
std::cin >> n >> C;
fore(i, 1, n + 1) std::cin >> L[i] >> R[i] >> val[i];
order.clear();
order.push_back(1);
dfs(1);
order.push_back(C);
Z ans = 1;
int idx = 1;
while(idx <= n){
while(idx <= n && ~order[idx]) ++idx;
int l = order[idx - 1];
int len = 0;
while(idx <= n && order[idx] == -1){
++idx;
len += 1;
}
int r = order[idx];
Z res = 1;
for(int i = r - l + len; i >= r - l + 1; --i) res *= i;
res *= comb.invfac(len);
ans *= res;
}
std::cout << ans << endl;
}
return 0;
}