二分法的lower_bound和upper_bound

in 算法 with 0 comment

以前写算法的时候一般都是直接用STL中的upper_boundlower_bound,导致有的时候真正手写二分时会有边界不清楚的情况,所以索性还是自己实现一下,STL的做法是对数组长度进行二分,这里是对左右边界进行二分,结果一样的。

注意取low + high可能会溢出。

用c++简单写了下4种情况。

#include<iostream>
#include<vector>
using namespace std;

/// STL upper_bound
/// 第一个>key的元素的下标,下标为数组大小,表示不存在。
int upper_bound(vector<int>& arr, int key) {
    int low = 0, high = arr.size();
    int mid = 0;
    while(low < high) {
        mid = low + ((high - low) >> 1);
        // key在[mid + 1,high],即使相等,也不是mid,因为是>
        if(arr[mid] <= key) {
            low = mid + 1;
        }
        // key在[low, mid],右边界不动。让左边界逼近即可。
        else {
            high = mid;
        }
    }
    // 最后必定是low=high, 如果找不到low>high
    return low;
}

/// STL lower_bound
/// 第一个>=key的元素的下标,下标为数组大小,表示不存在。
int lower_bound(vector<int>& arr, int key) {
    int low = 0, high = arr.size();
    int mid = 0;
    while(low < high) {
        mid = low + ((high - low) >> 1);
        // key在[mid + 1,high],不能取等,因为目标可能是取等的arr[mid]。
        // 假定mid=low,取等将导致low>mid,而算错区间。
        if(arr[mid] < key) {
            low = mid + 1;
        }
        // key在[low, mid]。这里取等没影响,因为high=mid,即使答案是mid也在区间内
        else {
            high = mid;
        }
    }
    // 最后必定是low=high, 如果找不到low>high
    return low;
}

/// 变种1 最后一个<key的元素
/// 就是low_bound的前一个,为-1时表示所有值都比key小
int last_less_key(vector<int>& arr, int key)
{
    return lower_bound(arr, key) - 1;
}

/// 变种2 最后一个<=key的元素
/// 就是upper_bound的前一个,为-1时表示所有值都比key小
int last_lessORequal_key(vector<int>& arr, int key)
{
    return upper_bound(arr, key) - 1;
}

int main() {
    vector<int> arr{1,2,2,2,3,3,6,8};
    /// 测试 1
    cout << "lower_bound" << endl;
    cout << lower_bound(arr, 0) << endl;
    cout << lower_bound(arr, 2) << endl;
    cout << lower_bound(arr, 4) << endl;
    cout << lower_bound(arr, 9) << endl;

    /// 测试 2
    cout << "upper_bound" << endl;
    cout << upper_bound(arr, 0) << endl;
    cout << upper_bound(arr, 2) << endl;
    cout << upper_bound(arr, 4) << endl;
    cout << upper_bound(arr, 9) << endl;
}
Responses