题目描述
已知有两个等长非降序序列
S
1
S_1
S1和
S
2
S_2
S2。先将
S
1
S_1
S1和
S
2
S_2
S2 合并为
S
3
S_3
S3 ,求
S
3
S_3
S3的中位数。
输入描述
第一行,序列
S
1
S_1
S1 和
S
2
S_2
S2 的长度
N
N
N ,
第二行,序列
S
1
S_1
S1 的
N
N
N 个整数,
第三行,序列
S
2
S_2
S2 的
N
N
N 个整数
输出描述
输出两个序列合并后序列 S 3 S_3 S3 的中位数。
思路分析
该题最优时间复杂度可以达到 O ( l o g ( 2 × N ) ) O(log(2 \times N)) O(log(2×N)) ,但从暴力的层面时间复杂度是 O ( 2 ∗ N ) O(2*N) O(2∗N) ,因此所有用 sort 排序输出的都是傻瓜。
首先我们对
S
1
S_1
S1 和
S
2
S_2
S2 取一下中位数,当
S
1
S_1
S1 的中位数小于
S
2
S_2
S2 的中位数时,那么
S
1
S_1
S1 中小于
S
1
m
1
{S_1}_{m1}
S1m1 一定小于
S
2
m
2
{S_2}_{m2}
S2m2 ,即中位数不会出现在
S
1
1
−
(
m
1
−
1
)
{S_1}_{1\ -\ (m1-1)}
S11 − (m1−1) 之中。同理,中位数也不会出现在
S
2
(
m
2
+
1
−
r
2
)
{S_2}_{(m2+1\ -\ r2)}
S2(m2+1 − r2)中。由此我们可以将这两部分删去,重新找新区间的中位数即可。如图
需要注意的是,我们递归过程中的出口是什么。一个坑点是合并以后的数集个数一定是偶数,因此我们需要找到
S
3
S_3
S3 中的第
N
N
N 大的数和第
N
+
1
N+1
N+1 大的数。如果我们把出口定在
l
1
=
=
r
1
l_1==r_1
l1==r1 的话就会出现对于两个中位数同时在
S
1
S_1
S1 或
S
2
S_2
S2 中时答案的误判,因此我们把出口放在
l
1
+
1
=
=
r
1
l_1+1==r_1
l1+1==r1 找一下中位数即可。
代码
#include <bits/stdc++.h>
using namespace std;
#define all(x) x.begin(), x.end()
#define bit1(x) __builtin_popcount(x)
#define Pqueue priority_queue
#define lc p << 1
#define rc p << 1 | 1
#define IOS ios::sync_with_stdio(false), cin.tie(0);
#define fi first
#define se second
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef pair<ll, ll> PII;
const ll mod = 1000000007;
const ll N = 1e6 + 10;
const ld eps = 1e-9;
const ll inf = 1e18;
const ll P = 131;
const ll dir[8][2] = {1, 0, 0, 1, -1, 0, 0, -1, 1, 1, 1, -1, -1, 1, -1, -1};
int n;
int s1[N], s2[N];
void merge(int l1, int r1, int l2, int r2)
{
if (l1 + 1 == r1)
{
int a, b;
if (s1[r1] >= s2[r2])
{
a = s2[r2];
if (s1[l1] <= s2[l2])
b = s2[l2];
else
b = s1[l1];
}
else
{
a = s1[r1];
if (s1[l1] <= s2[l2])
b = s2[l2];
else
b = s1[l1];
}
cout << (a + b) / 2;
return;
}
int m1 = l1 + r1 >> 1, m2 = l2 + r2 >> 1;
if (s1[m1] > s2[m2])
merge(l1, m1, m2, r2);
else
merge(m1, r1, l2, m2);
}
void solve()
{
cin >> n;
for (int i = 1; i <= n; i++)
cin >> s1[i];
for (int i = 1; i <= n; i++)
cin >> s2[i];
if (n == 1)
return void(cout << (s1[1] + s2[1]) / 2); //对长度为1的数集的特判
merge(1, n, 1, n);
}
int main()
{
int T = 1;
// cin>>T;
while (T--)
solve();
return 0;
}
/*
oxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxox
x o
o _/_/_/_/ _/ x
x _/ o
o _/_/_/_/ _/ _/_/ _/_/ _/_/_/ _/_/ _/_/_/ _/_/ _/_/_/ _/ _/ _/ x
x _/ _/_/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ o
o _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/_/ x
x _/ _/ _/_/ _/ _/ _/ _/_/_/ _/_/ _/ _/ _/ _/ _/ o
o _/ _/ _/ x
x _/ _/_/ _/ o
o x
xoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxo
*/