在统计和数据分析中,我们经常会遇到求最大值、最小值、中位数、四分位数、Top K等类似需求,其实它们都属于顺序统计量,本文将对顺序统计量的定义和求解算法进行介绍,重点介绍如何在最差时间复杂度也是线性的情况下求解第k大元素。
1. 顺序统计量与选择问题
在一个有 n n n个元素的集合中,第 i i i个顺序统计量是该集合中第 i i i小的元素。例如在集合 ( 1 , 3 , 5 , 2 ) (1, 3, 5, 2) (1,3,5,2)中,第2个顺序统计量为2。
从一个有 n n n个元素的集合中,选择出(求解)其第 i i i个顺序统计量的问题被称为选择问题。选择问题的输入输出如下:
输入:一个包含
n
n
n个不同的数的集合
A
A
A和一个数
i
i
i(
1
≤
i
≤
n
1 \leq i \leq n
1≤i≤n);
输出:元素
x
∈
A
x \in A
x∈A,它恰大于A中其它
(
i
−
1
)
(i-1)
(i−1)个元素。
2. 选择问题的求解方法
显然,对输入集合 A A A进行排序之后就可以解决选择问题,使用堆排序或归并排序对输入集合进行排序,然后在排序后的数组中标出第 i i i个元素,即可在 O ( n log n ) O(n\log{n}) O(nlogn)时间内完成求解,但还有更快的算法。
2.1 最大值与最小值
我们先考虑选择问题的特殊情况,只求解最大值或最小值,可以发现很容易在
O
(
n
)
O(n)
O(n)时间内完成求解。只需要遍历数组,进行(n-1)
次比较即可。以求解最小值为例,伪码如下:
MINIMUM(A)
min = A[1]
n = length[A]
for(i=2; i<=n; i++)
if(min > A[i])
min = A[i]
return min
在此基础上,增加一点难度,我们希望同时找最大值和最小值,是否还可以在 O ( n ) O(n) O(n)时间内完成求解呢?答案是肯定的,在遍历过程中不再每次比较一个元素,而是每次比较两个元素,两个元素中较小的元素与当前最小值比较,较大的元素和当前最大值比较,即每对元素需要3次比较即可。伪码如下:
MAX-MINIMUM(A)
if(length[A] is odd)
min = A[1]
max = min
else
min = MIN(A[1], A[2])
max = MAX(A[1], A[2])
i++
while(i <= length[A])
min = MIN(MIN(A[i], A[i+1]), min)
max = MAX(MAX(A[i], A[i+1]), max)
i = i + 2
return min, max
在当前最大值和最小值初始值的设定上,如果n
是奇数,将最大值和最小值均设置为第一个元素的值;如果n
是偶数,就对前两个元素做一次比较来决定最大值与最小值的初始值。因此,如果n
是奇数,那么总共做了
3
⌊
n
/
2
⌋
3\lfloor n/2 \rfloor
3⌊n/2⌋次比较;如果n
是偶数,那么总共做了
3
n
/
2
−
2
3n/2-2
3n/2−2次比较,时间复杂度为
O
(
n
)
O(n)
O(n)。
2.2 期望时间为线性的选择问题
我们回到一般的选择问题,看起来一般的选择问题比求最小值和最大值复杂得多,但神奇的算法仍然可以让我们在平均为线性的时间完成求解。这里再次使用了分治思想,借鉴了快速排序的随机划分方法,如果刚好划分元的左边有(i-1)
个元素,则找到第i
小的元素;否则,在划分元的左侧或右侧继续进行随机划分。伪码如下:
RANDOMIZED-SELECT(A, p, r, i) // A为数组,p为数组左边界,r为数组右边界,i为待求的顺序统计量序号
if(p == r) // 临界问题处理
return A[p]
q = RANDOMIZED-PARTITION(A, p, r) //进行划分,返回划分元下标
k = q – p + 1 // k=rank(A[q]) in A[p,…,r], 返回划分元的序号
if(i == k)
return A[q]
else if(i < k)
return RANDOMIZED-SELECT(A, p, q - 1, i)
else
return RANDOMIZED-SELECT(A, q + 1, r, i – k)
可以证明在平均情况下,算法的时间复杂度为 O ( n ) O(n) O(n)。而当运气不好时,每次都只能去除一个元素,算法的时间复杂度就可能达到 O ( n 2 ) O(n^2) O(n2)。
2.3 最差时间为线性的选择问题
在上述RANDOMIZED-SELECT
算法的基础上,保证每次对数组的划分是个好划分,我们就能进一步在最差情况下也用线性时间解决选择问题。主要步骤如下:
- 将
n
个元素每5个分为一组,一共 ⌈ n / 5 ⌉ \lceil n/5 \rceil ⌈n/5⌉组。最后一组有n mod 5
个元素。 - 对每组进行排序,取其中位数。若最后一组有偶数个元素,则取较小的中位数。
- 递归地使用本算法寻找
⌈
n
/
5
⌉
\lceil n/5 \rceil
⌈n/5⌉个中位数的中位数
x
。 - 用
x
作为划分元对数组A
进行划分,并设x
是第k
个最小元。 - 如果
i = k
,则返回x
;否则如果i < k
,则找左区间的第i
个最小元;如果i > k
,则找右区间的第i - k
个最小元。
伪码如下:
SELECT(A, p, r, i)
if(r - p <= 140)
用简单的排序算法对数组A[p..r进行排序
return A[p + k - 1]
n = r - p + 1
for(i = 0; i <= floor(n/5); i++) //寻找每组的中位数
将A[p+5*i]至A[p+5*i+4]的第3小元素与A[p+i]交换位置
x = SELECT(A, p, p+floor(n/5), floor(n/10)) //找中位数的中位数
i = PARTITION(A, p, r, x)
j = i - p + 1
if(k <= j)
return SELECT(A, p, i, k)
else
return SELECT(A, i + 1, r, k - j)
3. 程序代码
以下C语言程序代码实现了最坏情况为线性的select
算法,将“求数组a[1..n]
中第k大的元素”转化为“求数组a[1..n]
中第(n-k+1)
小的元素”。递归调用select
时,设置数组长度不大于140时,即直接使用插入排序。
3.1 linearSelect_kth.cpp
#include <stdio.h>
#include <stdlib.h>
#define N 1000000 //定义输入数组的最大长度
#define LEN 5 //定义select中每组元素的个数
int a[N];
void swap(int *a, int *b) { //交换 a 与 b 的值
int tmp = *a;
*a = *b;
*b = tmp;
}
int partition(int a[], int low, int high, int pivot) { //将数组a[low..high]划分为 <= pivot和 > pivot的两部分
int x;
int i = low - 1;
int j;
for (j = low; j < high; j++) { //在数组中找到值等于privot的元素作为主元,交换到数组最右端
if (a[j] == pivot) {
swap(&a[j], &a[high]);
}
}
x = a[high];
for (j = low; j < high; j++) { //维护低区a[low..i] <= x, 高区a[i+1..j-1] > x
if (a[j] <= x) { //如果发现a[j] <= x,则将a[j]交换到低区
i++;
swap(&a[i], &a[j]);
}
}
swap(&a[i + 1], &a[high]); //将主元与最左的大于 x 的元素a[i+1]交换,此时主元到了它应在的位置
return i + 1; //返回分区完成后主元所在的新下标
}
void insertSort(int a[], int low, int high) { //对a[low..high]进行插入排序
int i, j;
for (i = low + 1; i <= high; i++) {
int temp = a[i];
for (j = i - 1; j >= low && temp < a[j]; j--) {
a[j + 1] = a[j];
}
a[j + 1] = temp;
}
}
int select(int a[], int begin, int end, int k) { //选出数组a[begin..end]的第k小元素
int length = end - begin + 1; //数组长度,即数组中元素的个数
if (length <= 140) { //长度较小,直接用插入排序
insertSort(a, begin, end);
return a[begin + k - 1];
}
int groups = (length + LEN) / LEN; //组数
int i;
for (i = 0; i < groups; i++) {
int left = begin + LEN * i; //第i组的左边界
int right = (begin + LEN * i + LEN - 1) > end ? end : (begin + LEN * i + LEN - 1); //第i组的右边界
insertSort(a, left, right); //组内进行插入排序
//将第i组中位数与数组a[]的第i个元素互换位置,方便递归select寻找中位数的中位数
int mid = (left + right) / 2;
swap(&a[begin + i], &a[mid]);
}
int pivot = select(a, begin, begin + groups - 1, (groups + 1) / 2); //找出中位数的中位数
int p = partition(a, begin, end, pivot); //用中位数的中位数作为划分的主元
int leftNum = p - begin; //低区元素的数量
if (k == leftNum + 1) {
return a[p];
}
else if (k <= leftNum) {
return select(a, begin, p - 1, k); //在低区递归调用select来找出第k小的元素
}
else {
return select(a, p + 1, end, k - leftNum -1); //在高区递归调用select来找出第(k-leftNum-1)小的元素
}
}
int main() {
FILE *fp = fopen("data_1022.txt","r"); //打开文件
if (fp == NULL) {
printf("Can not open the file!\n");
exit(0);
}
int i = 0;
while (fscanf(fp, "%d\n", &a[i]) != EOF) { //读取文件中的数据到数组a[]中
i++;
}
fclose(fp); //关闭文件
int k;
while (1) {
printf("Please enter an integer k, and you will get the k-th largest element in the array!\n");
printf("(Enter negative or zero to quit): ");
scanf("%d", &k);
if (k <= 0) {
printf("Bye\n");
break;
}
printf("The %dth largest element in the array is: %d\n", k, select(a, 0, i - 1, i - k + 1));
printf("\n==================================================================================\n");
}
return 0;
}
3.2 linearSelect_kth_grouplenth_5_vs_7_vs_3.cpp
然后,在3.1节代码的基础上,我尝试改变每组元素的个数,分别设置每组元素个数为5、7、3,比较算法运行时间的差异。
#include <stdio.h>
#include <stdlib.h>
#include <windows.h>
#define N 1000000 //定义输入数组的最大长度
#define LEN1 5 //尝试改变select中每组元素的个数
#define LEN2 7
#define LEN3 3
int a[N];
void swap(int *a, int *b) { //交换 a 与 b 的值
int tmp = *a;
*a = *b;
*b = tmp;
}
int partition(int a[], int low, int high, int pivot) { //将数组a[low..high]划分为 <= pivot和 > pivot的两部分
int x;
int i = low - 1;
int j;
for (j = low; j < high; j++) { //在数组中找到值等于privot的元素作为主元,交换到数组最右端
if (a[j] == pivot) {
swap(&a[j], &a[high]);
}
}
x = a[high];
for (j = low; j < high; j++) { //维护低区a[low..i] <= x, 高区a[i+1..j-1] > x
if (a[j] <= x) { //如果发现a[j] <= x,则将a[j]交换到低区
i++;
swap(&a[i], &a[j]);
}
}
swap(&a[i + 1], &a[high]); //将主元与最左的大于 x 的元素a[i+1]交换,此时主元到了它应在的位置
return i + 1; //返回分区完成后主元所在的新下标
}
void insertSort(int a[], int low, int high) { //对a[low..high]进行插入排序
int i, j;
for (i = low + 1; i <= high; i++) {
int temp = a[i];
for (j = i - 1; j >= low && temp < a[j]; j--) {
a[j + 1] = a[j];
}
a[j + 1] = temp;
}
}
int select_5(int a[], int begin, int end, int k) { //选出数组a[begin..end]的第k小元素,分组长度为5
int length = end - begin + 1; //数组长度,即数组中元素的个数
if (length <= 140) { //长度较小,直接用插入排序
insertSort(a, begin, end);
return a[begin + k - 1];
}
int groups = (length + LEN1) / LEN1; //组数
int i;
for (i = 0; i < groups; i++) {
int left = begin + LEN1 * i; //第i组的左边界
int right = (begin + LEN1 * i + LEN1 - 1) > end ? end : (begin + LEN1 * i + LEN1 - 1); //第i组的右边界
insertSort(a, left, right); //组内进行插入排序
//将第i组中位数与数组a[]的第i个元素互换位置,方便递归select寻找中位数的中位数
int mid = (left + right) / 2;
swap(&a[begin + i], &a[mid]);
}
int pivot = select_5(a, begin, begin + groups - 1, (groups + 1) / 2); //找出中位数的中位数
int p = partition(a, begin, end, pivot); //用中位数的中位数作为划分的主元
int leftNum = p - begin; //低区元素的数量
if (k == leftNum + 1) {
return a[p];
}
else if (k <= leftNum) {
return select_5(a, begin, p - 1, k); //在低区递归调用select来找出第k小的元素
}
else {
return select_5(a, p + 1, end, k - leftNum -1); //在高区递归调用select来找出第(k-leftNum-1)小的元素
}
}
int select_7(int a[], int begin, int end, int k) {
int length = end - begin + 1;
if (length <= 140) {
insertSort(a, begin, end);
return a[begin + k - 1];
}
int groups = (length + LEN2) / LEN2;
int i;
for (i = 0; i < groups; i++) {
int left = begin + LEN2 * i; //第i组的左边界
int right = (begin + LEN2 * i + LEN2 - 1) > end ? end : (begin + LEN2 * i + LEN2 - 1); //第i组的右边界
insertSort(a, left, right); //组内进行插入排序
//将第i组中位数与数组a[]的第i个元素互换位置,方便递归select寻找中位数的中位数
int mid = (left + right) / 2;
swap(&a[begin + i], &a[mid]);
}
int pivot = select_7(a, begin, begin + groups - 1, (groups + 1) / 2); //找出中位数的中位数
int p = partition(a, begin, end, pivot); //用中位数的中位数作为划分的主元
int leftNum = p - begin; //低区元素的数量
if (k == leftNum + 1) {
return a[p];
}
else if (k <= leftNum) {
return select_7(a, begin, p - 1, k); //在低区递归调用select来找出第k小的元素
}
else {
return select_7(a, p + 1, end, k - leftNum -1); //在高区递归调用select来找出第(k-leftNum-1)小的元素
}
}
int select_3(int a[], int begin, int end, int k) {
int length = end - begin + 1;
if (length <= 140) {
insertSort(a, begin, end);
return a[begin + k - 1];
}
int groups = (length + LEN3) / LEN3;
int i;
for (i = 0; i < groups; i++) {
int left = begin + LEN3 * i; //第i组的左边界
int right = (begin + LEN3 * i + LEN3 - 1) > end ? end : (begin + LEN3 * i + LEN3 - 1); //第i组的右边界
insertSort(a, left, right); //组内进行插入排序
//将第i组中位数与数组a[]的第i个元素互换位置,方便递归select寻找中位数的中位数
int mid = (left + right) / 2;
swap(&a[begin + i], &a[mid]);
}
int pivot = select_3(a, begin, begin + groups - 1, (groups + 1) / 2); //找出中位数的中位数
int p = partition(a, begin, end, pivot); //用中位数的中位数作为划分的主元
int leftNum = p - begin; //低区元素的数量
if (k == leftNum + 1) {
return a[p];
}
else if (k <= leftNum) {
return select_3(a, begin, p - 1, k); //在低区递归调用select来找出第k小的元素
}
else {
return select_3(a, p + 1, end, k - leftNum -1); //在高区递归调用select来找出第(k-leftNum-1)小的元素
}
}
int main() {
FILE *fp = fopen("data_1022.txt","r"); //打开文件
if (fp == NULL) {
printf("Can not open the file!\n");
exit(0);
}
int i = 0;
while (fscanf(fp, "%d\n", &a[i]) != EOF) { //读取文件中的数据到数组a[]中
i++;
}
fclose(fp); //关闭文件
printf("Please enter an integer k, and you will get the k-th largest element in the array:\n");
int k;
scanf("%d", &k);
printf("********************* Group length is 5, array size is 945800**********************\n");
LARGE_INTEGER nFreq;
LARGE_INTEGER nBeginTime;
LARGE_INTEGER nEndTime;
QueryPerformanceFrequency(&nFreq);
QueryPerformanceCounter(&nBeginTime);
printf("The %dth largest element in the array is: %d\n", k, select_5(a, 0, i - 1, i - k + 1));
QueryPerformanceCounter(&nEndTime); //计时结束
double time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
printf("Running time: %lfms\n\n", time);
printf("********************* Group length is 7, array size is 945800 **********************\n");
QueryPerformanceFrequency(&nFreq);
QueryPerformanceCounter(&nBeginTime);
printf("The %dth largest element in the array is: %d\n", k, select_7(a, 0, i - 1, i - k + 1));
QueryPerformanceCounter(&nEndTime); //计时结束
time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
printf("Running time: %lfms\n\n", time);
printf("********************* Group length is 3, array size is 945800 **********************\n");
QueryPerformanceFrequency(&nFreq);
QueryPerformanceCounter(&nBeginTime);
printf("The %dth largest element in the array is: %d\n", k, select_3(a, 0, i - 1, i - k + 1));
QueryPerformanceCounter(&nEndTime); //计时结束
time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
printf("Running time: %lfms\n\n", time);
printf("====================== Group length is 5, array size is 10000 ======================\n");
QueryPerformanceFrequency(&nFreq);
QueryPerformanceCounter(&nBeginTime);
printf("The %dth largest element in the array is: %d\n", k, select_5(a, 0, 9999, 10000 - k + 1));
QueryPerformanceCounter(&nEndTime); //计时结束
time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
printf("Running time: %lfms\n\n", time);
printf("====================== Group length is 7, array size is 10000 ======================\n");
QueryPerformanceFrequency(&nFreq);
QueryPerformanceCounter(&nBeginTime);
printf("The %dth largest element in the array is: %d\n", k, select_7(a, 0, 9999, 10000 - k + 1));
QueryPerformanceCounter(&nEndTime); //计时结束
time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
printf("Running time: %lfms\n\n", time);
printf("====================== Group length is 3, array size is 10000 ======================\n");
QueryPerformanceFrequency(&nFreq);
QueryPerformanceCounter(&nBeginTime);
printf("The %dth largest element in the array is: %d\n", k, select_3(a, 0, 9999, 10000 - k + 1));
QueryPerformanceCounter(&nEndTime); //计时结束
time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
printf("Running time: %lfms\n\n", time);
printf("######################## Group length is 5, array size is 1000 ########################\n");
QueryPerformanceFrequency(&nFreq);
QueryPerformanceCounter(&nBeginTime);
printf("The %dth largest element in the array is: %d\n", k, select_5(a, 0, 999, 1000 - k + 1));
QueryPerformanceCounter(&nEndTime); //计时结束
time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
printf("Running time: %lfms\n\n", time);
printf("######################## Group length is 7, array size is 1000 ########################\n");
QueryPerformanceFrequency(&nFreq);
QueryPerformanceCounter(&nBeginTime);
printf("The %dth largest element in the array is: %d\n", k, select_7(a, 0, 999, 1000 - k + 1));
QueryPerformanceCounter(&nEndTime); //计时结束
time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
printf("Running time: %lfms\n\n", time);
printf("######################## Group length is 3, array size is 1000 ########################\n");
QueryPerformanceFrequency(&nFreq);
QueryPerformanceCounter(&nBeginTime);
printf("The %dth largest element in the array is: %d\n", k, select_3(a, 0, 999, 1000 - k + 1));
QueryPerformanceCounter(&nEndTime); //计时结束
time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
printf("Running time: %lfms\n\n", time);
return 0;
}
4. 运行结果
程序使用的测试数据可在本文所附资源处或点击此链接下载。在linearSelect_kth.cpp中,设置分组长度为5。运行该程序,程序循环提示输入整数k,按下回车后会输出第k大的元素,直至输入一个负数或0,程序终止。
在linearSelect_kth_grouplenth_5_vs_7_vs_3.cpp中,大致比较了算法在分组长度为5、7、3以及在不同问题规模的情况下的运行时间。运行该程序,程序提示输入一个整数k,按下回车后,程序依次输出算法分组长度为5、7、3分别在数组长度为945800、10000、1000时的运行时间。
4.1 linearSelect_kth.cpp
- 第1大: 9999990
- 第5大: 9999940
- 第7大: 9999915
- 第90大: 9998974
- 第100大:9998835
4.2 linearSelect_kth_grouplenth_5_vs_7_vs_3.cpp
多次运行发现,在数组长度为945800时,基本上运行时间都是组长为7 < 组长为5 < 组长为3。在问题规模减小时,三者的运行时间大小关系略有波动,猜测可能是由于组长为3的select算法是非线性的以及程序运行计时存在误差等。