二分问题之前遇到很多次了,不过一直是手写完整二分,现在转变一下想法,直接使用函数lower_bound和upper_bound更方便
lower_bound
有序数组中 查找第一个不小于指定值的位置。
本质二分代码:
int lower_bound_custom(int* arr, int n, int val) {
int low = 0, high = n; // 查找范围 [low, high)
while (low < high) {
int mid = low + (high - low) / 2;
if (arr[mid] < val) {
low = mid + 1; // 继续查找右区间
} else {
high = mid; // 继续查找左区间
}
}
return low; // 返回第一个不小于 val 的位置
}
upper_bound
有序数组中 查找第一个大于指定值的位置
本质二分代码:
int upper_bound_custom(int* arr, int n, int val) {
int low = 0, high = n; // 查找范围 [low, high)
while (low < high) {
int mid = low + (high - low) / 2;
if (arr[mid] <= val) {
low = mid + 1; // 继续查找右区间
} else {
high = mid; // 继续查找左区间
}
}
return low; // 返回第一个大于 val 的位置
}
注意 lower_bound和upper_bound 是指针用法,最后返回的是位置索引
真题实战
题目链接:1.递增三元组 - 蓝桥云课
利用 lower_bound和upper_bound 减少时间复杂度
代码一:
纯暴力思想,发现有2个测试样例无法通过,时间超时,此时时间复杂度为O(n^3)
#include<bits/stdc++.h>
using namespace std;
int n;
int a[100010],b[100010],c[100010];
long long sum=0;
int main()
{
cin>>n;
for(int i=1; i<=n; i++) cin>>a[i];
for(int i=1; i<=n; i++) cin>>b[i];
for(int i=1; i<=n; i++) cin>>c[i];
for(int i=1; i<=n; i++)
{
for(int j=1; j<=n; j++)
{
for(int k=1; k<=n; k++)
{
if(a[i]<b[j] && b[j]<c[k])
{
sum++;
}
}
}
}
cout<<sum<<endl;
return 0;
}
代码二:
优化遍历,发现只需要两重for循环,找到每个b[i]位置符合的a[]个数和c[]个数,累成再累加,但是还是有一个样例无法通过,此时时间复杂度为O(n^2)
#include<bits/stdc++.h>
using namespace std;
int n;
int a[100010],b[100010],c[100010];
long long sum=0;
int main()
{
cin>>n;
for(int i=1; i<=n; i++) cin>>a[i];
for(int i=1; i<=n; i++) cin>>b[i];
for(int i=1; i<=n; i++) cin>>c[i];
for(int i=1;i<=n;i++)
{
int sum_a=0,sum_c=0;
for(int j=1;j<=n;j++)
{
if(a[j]<b[i])
{
sum_a++;
}
if(b[i]<c[j])
{
sum_c++;
}
}
sum+=1LL*sum_a*sum_c;
}
cout<<sum<<endl;
return 0;
}
代码三:
最后想到通过二分排序,减少时间复杂度,通过所有测试样例,此时时间复杂度为O(n^logn)
#include<bits/stdc++.h>
using namespace std;
int n;
int a[100010],b[100010],c[100010];
long long sum=0;
int main()
{
cin>>n;
for(int i=1; i<=n; i++) cin>>a[i];
for(int i=1; i<=n; i++) cin>>b[i];
for(int i=1; i<=n; i++) cin>>c[i];
sort(a+1,a+n+1);
sort(b+1,b+n+1);
sort(c+1,c+n+1);
for(int j=1;j<=n;j++)
{
int sum_a=lower_bound(a+1,a+n+1,b[j])-(a+1); //满足 a[i] < b[j]
int sum_c=(c+n+1)-upper_bound(c+1,c+n+1,b[j]); //满足 b[j] < c[k]
sum+=1LL*sum_a*sum_c; //可能超int型范围
}
cout<<sum<<endl;
return 0;
}