【C++】剖析lower_bound & upper_bound
接口
先看看接口
std::lower_bound
default (1)
template <class ForwardIterator, class T> ForwardIterator lower_bound (ForwardIterator first, ForwardIterator last,const T& val);
custom (2)
template <class ForwardIterator, class T, class Compare> ForwardIterator lower_bound (ForwardIterator first, ForwardIterator last,const T& val, Compare comp);
std::upper_bound
default (1)
template <class ForwardIterator, class T> ForwardIterator upper_bound (ForwardIterator first, ForwardIterator last,const T& val);
custom (2)
template <class ForwardIterator, class T, class Compare> ForwardIterator upper_bound (ForwardIterator first, ForwardIterator last,const T& val, Compare comp);
实现
下面是与标准库行为等价的参考实现(摘自 cplusplus.com):
参考:https://cplusplus.com/reference/algorithm/lower_bound/,https://cplusplus.com/reference/algorithm/upper_bound/
std::lower_bound
// Reference implementation equivalent to std::lower_bound, version (1).
//
// Returns an iterator to the first element in [first, last) that is NOT less
// than val (i.e. the first e with !(e < val)), or last if there is none.
// Precondition: [first, last) is partitioned with respect to (e < val),
// which holds for any range sorted in ascending order by operator<.
//
// Each iteration halves `count`, so about log2(n) + 1 comparisons are made;
// for non-random-access iterators std::advance makes the stepping linear.
//
// NOTE: this lives in the global namespace for illustration. If you also have
// `using namespace std;` (or pass std:: iterators, which triggers ADL), call it
// as ::lower_bound(...) to avoid ambiguity with std::lower_bound.
template <class ForwardIterator, class T>
ForwardIterator lower_bound (ForwardIterator first, ForwardIterator last, const T& val)
{
    // `typename` is mandatory: difference_type depends on the template parameter.
    using diff_t = typename std::iterator_traits<ForwardIterator>::difference_type;

    diff_t count = std::distance(first, last); // size of the remaining search window [first, first + count)
    while (count > 0)
    {
        const diff_t step = count / 2;
        ForwardIterator it = first;
        std::advance(it, step);               // it -> middle of the window
        if (*it < val) {                      // or: if (comp(*it,val)), for version (2)
            // *it is still "before" val: the answer lies strictly after it.
            first = ++it;
            count -= step + 1;
        }
        else {
            // *it is not before val: the answer is `it` or somewhere left of it.
            count = step;
        }
    }
    return first;
}
std::upper_bound
// Reference implementation equivalent to std::upper_bound, version (1).
//
// Returns an iterator to the first element in [first, last) that is strictly
// GREATER than val (i.e. the first e with val < e), or last if there is none.
// Precondition: [first, last) is partitioned with respect to !(val < e),
// which holds for any range sorted in ascending order by operator<.
//
// Note the operand order: the searched value is on the LEFT of operator<,
// whereas lower_bound puts the element on the left.
//
// NOTE: this lives in the global namespace for illustration. If you also have
// `using namespace std;` (or pass std:: iterators, which triggers ADL), call it
// as ::upper_bound(...) to avoid ambiguity with std::upper_bound.
template <class ForwardIterator, class T>
ForwardIterator upper_bound (ForwardIterator first, ForwardIterator last, const T& val)
{
    // `typename` is mandatory: difference_type depends on the template parameter.
    using diff_t = typename std::iterator_traits<ForwardIterator>::difference_type;

    diff_t count = std::distance(first, last); // size of the remaining search window [first, first + count)
    while (count > 0)
    {
        const diff_t step = count / 2;
        ForwardIterator it = first;
        std::advance(it, step);               // it -> middle of the window
        if (!(val < *it)) {                   // or: if (!comp(val,*it)), for version (2)
            // *it <= val: the answer lies strictly after it.
            first = ++it;
            count -= step + 1;
        }
        else {
            // val < *it: the answer is `it` or somewhere left of it.
            count = step;
        }
    }
    return first;
}
使用
lower_bound 和 upper_bound 的具体用法取决于两点:数组是升序还是降序,以及传入的查找值类型是否与数组元素类型一致。
升序数组
根据下面的 MSVC STL 源码,对于升序数组,无论是内置类型还是重载了 operator< 的自定义类型,都无需多说,直接使用版本 (1)(default)即可,因为它在底层调用的是 less<>{},也就是 operator<。
// lower_bound-------------------------------------
// MSVC STL: std::lower_bound with a user-supplied predicate (version (2)).
// Binary search over [_First, _Last): returns the first position whose element e
// makes _Pred(e, _Val) false. Note the call order _Pred(element, value) -- this is
// why a heterogeneous comparator for lower_bound must take the element FIRST.
_EXPORT_STD template <class _FwdIt, class _Ty, class _Pr>
_NODISCARD _CONSTEXPR20 _FwdIt lower_bound(_FwdIt _First, const _FwdIt _Last, const _Ty& _Val, _Pr _Pred) {
// find first element not before _Val
// NOTE(review): _Adl_verify_range / _Get_unwrapped / _Seek_wrapped are MSVC-internal
// helpers not shown here; presumably they validate the range and convert between the
// caller's (possibly checked/wrapped) iterator and a raw (unwrapped) one.
_STD _Adl_verify_range(_First, _Last);
auto _UFirst = _STD _Get_unwrapped(_First);
// _Count = number of elements still under consideration, starting at _UFirst.
_Iter_diff_t<_FwdIt> _Count = _STD distance(_UFirst, _STD _Get_unwrapped(_Last));
while (0 < _Count) { // divide and conquer, find half that contains answer
const _Iter_diff_t<_FwdIt> _Count2 = _Count / 2;
const auto _UMid = _STD next(_UFirst, _Count2);
if (_Pred(*_UMid, _Val)) { // try top half
// *_UMid is ordered before _Val: the answer lies after _UMid, so drop _UMid and everything left of it.
_UFirst = _STD _Next_iter(_UMid);
_Count -= _Count2 + 1;
} else {
// *_UMid is not before _Val: the answer is _UMid or earlier, so keep only the _Count2 elements left of it.
_Count = _Count2;
}
}
// Presumably writes the unwrapped result position back into _First's iterator type (internal helper).
_STD _Seek_wrapped(_First, _UFirst);
return _First;
}
// MSVC STL: std::lower_bound without a predicate (version (1)).
// Simply forwards to the predicate overload with less<>{}, i.e. it compares with operator<.
_EXPORT_STD template <class _FwdIt, class _Ty>
_NODISCARD _CONSTEXPR20 _FwdIt lower_bound(_FwdIt _First, _FwdIt _Last, const _Ty& _Val) {
// find first element not before _Val
return _STD lower_bound(_First, _Last, _Val, less<>{});// ⭐ default ordering is less<>{} (operator<), so it only fits ascending ranges
}
// upper_bound---------------------------------------
// MSVC STL: std::upper_bound with a user-supplied predicate (version (2)).
// Binary search over [_First, _Last): returns the first position whose element e
// makes _Pred(_Val, e) true. Note the call order _Pred(value, element) -- the
// reverse of lower_bound -- so a heterogeneous comparator here takes the value FIRST.
_EXPORT_STD template <class _FwdIt, class _Ty, class _Pr>
_NODISCARD _CONSTEXPR20 _FwdIt upper_bound(_FwdIt _First, _FwdIt _Last, const _Ty& _Val, _Pr _Pred) {
// find first element that _Val is before
// NOTE(review): _Adl_verify_range / _Get_unwrapped / _Seek_wrapped are MSVC-internal
// helpers not shown here; presumably they validate the range and convert between the
// caller's (possibly checked/wrapped) iterator and a raw (unwrapped) one.
_STD _Adl_verify_range(_First, _Last);
auto _UFirst = _STD _Get_unwrapped(_First);
// _Count = number of elements still under consideration, starting at _UFirst.
_Iter_diff_t<_FwdIt> _Count = _STD distance(_UFirst, _STD _Get_unwrapped(_Last));
while (0 < _Count) { // divide and conquer, find half that contains answer
_Iter_diff_t<_FwdIt> _Count2 = _Count / 2;
const auto _UMid = _STD next(_UFirst, _Count2);
if (_Pred(_Val, *_UMid)) {
// _Val is ordered before *_UMid: the answer is _UMid or earlier, so keep only the _Count2 elements left of it.
_Count = _Count2;
} else { // try top half
// _Val is not before *_UMid: the answer lies after _UMid, so drop _UMid and everything left of it.
_UFirst = _STD _Next_iter(_UMid);
_Count -= _Count2 + 1;
}
}
// Presumably writes the unwrapped result position back into _First's iterator type (internal helper).
_STD _Seek_wrapped(_First, _UFirst);
return _First;
}
// MSVC STL: std::upper_bound without a predicate (version (1)).
// Simply forwards to the predicate overload with less<>{}, i.e. it compares with operator<.
_EXPORT_STD template <class _FwdIt, class _Ty>
_NODISCARD _CONSTEXPR20 _FwdIt upper_bound(_FwdIt _First, _FwdIt _Last, const _Ty& _Val) {
// find first element that _Val is before
return _STD upper_bound(_First, _Last, _Val, less<>{});// ⭐ default ordering is less<>{} (operator<), so it only fits ascending ranges
}
降序数组
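但如果是降序数组,就必须传入比较函数(版本 (2)),并且它必须与排序时使用的比较规则完全一致(例如 a.age > b.age)。注意此时两个函数的含义也随之"翻转":lower_bound 返回第一个不大于目标值的元素(age <= val),upper_bound 返回第一个小于目标值的元素(age < val)。例如: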
#include <algorithm>
#include <iostream>
#include <string>   // Person uses std::string; don't rely on <iostream> pulling it in transitively
#include <vector>
// A simple record; the examples order and search people by `age`.
struct Person {
    std::string name;
    int age;
};
// Strict weak ordering that puts older people first (descending by age).
// Use this same ordering both to sort and to search the sorted range.
bool compareByAgeDesc(const Person& p1, const Person& p2) {
    return p2.age < p1.age;
}
int main() {
// 创建一个 Person 类型的 vector,并按年龄降序排序
std::vector<Person> people = {
{"Alice", 30},
{"Bob", 25},
{"Charlie", 35},
{"David", 40}
};
// 按年龄降序排列
std::sort(people.begin(), people.end(), compareByAgeDesc);
// 查找年龄大于或等于 30 的第一个 Person
Person p = {"", 30}; // 目标年龄是 30
auto it_lower = std::lower_bound(people.begin(), people.end(), p, [](const Person& p1, const Person& p2) {
return p1.age > p2.age; // 按降序比较
});
if (it_lower != people.end()) {
std::cout << "Lower bound (first person >= 30): " << it_lower->name << ", Age: " << it_lower->age << std::endl;
}
// 查找年龄大于 30 的第一个 Person
p.age = 30; // 查找大于 30 的
auto it_upper = std::upper_bound(people.begin(), people.end(), p, [](const Person& p1, const Person& p2) {
return p1.age > p2.age; // 按降序比较
});
if (it_upper != people.end()) {
std::cout << "Upper bound (first person > 30): " << it_upper->name << ", Age: " << it_upper->age << std::endl;
}
return 0;
}
非同类型元素查找
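上面的例子中,查找值的类型和数组元素类型一致。如果不一致(例如直接用一个 int 年龄去查 Person 数组),就要特别注意两个 lambda 的形参顺序和比较方式:lower_bound 在底层调用 comp(元素, 查找值),而 upper_bound 在底层调用 comp(查找值, 元素)。两个 lambda 的形参顺序之所以不同,完全是为了匹配底层实现。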
#include <algorithm>
#include <iostream>
#include <string>   // Person uses std::string; don't rely on <iostream> pulling it in transitively
#include <vector>
// A simple record; the examples order and search people by `age`.
struct Person {
    std::string name;
    int age;
};
// Strict weak ordering that puts older people first (descending by age).
// The heterogeneous search lambdas in main() must agree with this ordering.
bool compareByAgeDesc(const Person& p1, const Person& p2) {
    return p2.age < p1.age;
}
int main() {
    // Build a vector of Person and sort it by age, oldest first.
    std::vector<Person> people = {
        {"Alice", 30},
        {"Bob", 25},
        {"Charlie", 35},
        {"David", 40}
    };
    std::sort(people.begin(), people.end(), compareByAgeDesc);
    // people is now ordered by age: 40, 35, 30, 25.

    const int age = 30; // search key is a bare int, not a Person

    // lower_bound calls comp(element, value): the element comes FIRST.
    // It returns the first person for whom comp is false, i.e. first with age <= 30.
    auto it_lower = std::lower_bound(people.begin(), people.end(), age, [](const Person& p, int target) {
        return p.age > target; // ⭐ element first, value second; same "descending" ordering as the sort
    });
    if (it_lower != people.end()) {
        std::cout << "Lower bound (first person with age <= 30): " << it_lower->name << ", Age: " << it_lower->age << '\n';
    }

    // upper_bound calls comp(value, element): the value comes FIRST.
    // It returns the first person for whom comp is true, i.e. first with age < 30.
    auto it_upper = std::upper_bound(people.begin(), people.end(), age, [](int target, const Person& p) {
        return target > p.age; // ⭐ value first, element second; same "descending" ordering as the sort
    });
    if (it_upper != people.end()) {
        std::cout << "Upper bound (first person with age < 30): " << it_upper->name << ", Age: " << it_upper->age << '\n';
    }
    return 0;
}
lower_bound 的 lambda 形参为 (const Person&, int),比较的是"元素.age > 查找值",对应底层的 if (comp(*it, val)),即元素在前、查找值在后。
upper_bound 的 lambda 形参为 (int, const Person&),比较的是"查找值 > 元素.age",对应底层的 if (!comp(val, *it)),即查找值在前、元素在后。
所以一切使用上的细节都来源于底层实现!