二分查找

September 18, 2020

Although the basic idea of binary search is comparatively straightforward, the details can be surprisingly tricky.

—Donald Knuth. The Art of Computer Programming.

二分查找是一种理解起来非常简单,但不管是初学算法还是有一定基础的人都很容易掉进陷阱里的算法。

直觉 #

首先介绍一个猜数字的游戏,从1~100一共100个数字中任取一个数字target,竞猜者每猜一个数字都会得到“太小”,“太大”,“猜对了”这三种反馈,现要求竞猜者尽可能用最少的次数猜中这个数字。不少人应该能想到,先去猜50,如果太小就去选右边区间的中位数75,如果太大则去选25,依次类推,这样的算法的时间复杂度为 \(\Omicron(logN)\) ,二分查找这样就从直觉上理解了。

事实上二分查找是试错法(trail and error)一种应用,通过在搜索空间中进行猜测,再通过猜测的结果缩小搜索空间,最后搜索空间收敛便得到我们需要的答案,这篇LeetCode上的帖子对这种思想有着极好的阐述。如果在验证猜测结果后只是将猜测结果从搜索空间中排除出去,那么试错法便会特化为常说的暴力解法。

好的,先写出这个猜数字游戏的代码:

// guess(int x) is the API that returns 1 if x greater than answer, 0 if equal and -1 if less.

int guessNumber() {
    int lo = 1, hi = 100;
    // 1. exit condition
    while (lo <= hi) {
        // 2. how to compute middle value
        int mid = (lo + hi) / 2;
        // 3. how to move boundaries lo or hi
        if (guess(mid) > 0) hi = mid - 1;
        else if (guess(mid) < 0) lo = mid + 1;
        else return mid;
    }
    // impossible
    return -1;
}

如果猜的数mid小于答案,那么我们把lo移到mid+1这样以后猜的数不会比这更小,同理把hi移到mid-1,如果mid命中结果直接返回,这样写出来的二分查找初学者理解起来非常容易,也很直观看出来它最后总会收敛退出循环。

根据二分查找退出循环条件,二分查找的写法可以分为三种,而这三种写法我打算用区间分析法的思想来进行命名。

开区间和开区间 #

希望看这篇文章的读者没有忘掉初中学的关于函数区间的知识,即使忘了也没关系自行搜索一下快速了解开区间和闭区间的概念。

为了与后面保持一致,上面的写法可以改写为:

int guessNumber() {
    int lo = 1, hi = 100;
    while (lo <= hi) {
        int mid = (lo + hi) / 2;
        if (guess(mid) > 0) hi = mid - 1;
        else lo = mid + 1;
    }
    return hi;
}

这么写看起来有点古怪,现在开始引入区间分析法。假如我们把1到100当成一个连续单调增的搜索区间,在区间两端分别插入哨兵 \(-\infty\) \(\infty\) ,如下图所示:

img

如果像上图一样把左端到lohi到右端的区间分别想象成分别具有某种属性的区间,而我们要寻找的答案正好位于这两个区间的交汇点,那么这两个区间分别是什么样的呢?由退出条件lo <= hi可以知道当循环退出时会满足关系式hi + 1 == lo,如果要保证最后两个区间完全覆盖整个搜索区间,那么这两个区间一个为 \(\lbrack-\infty,lo)\) 的左闭右开区间,另一个为 \((hi, +\infty\rbrack\) 的左开右闭区间,如果只关注lohi的话那么这两个区间都为开区间。而通过上述代码片段可以看出来 \(\lbrack-\infty,lo)\) 中的所有数满足小于等于答案,而 \((hi, +\infty\rbrack\) 中的数则大于答案。一开始初始值的设定就可以搞清楚了,在没有搜索开始时, \(\lbrack-\infty, 1)\) \((100, +\infty\rbrack\) 时不会违反上述的属性。而每次搜索得出mid的属性之后,假如guess(mid) <= 0那么就令lo = mid + 1mid以及左边的所有值都包含在lo所在的开区间,而当guess(mid) > 0时同理让hi = mid - 1

当搜索结束时,由于答案一定存在和整个区间为单调增,所以最终结果一定为lo - 1也即为hi这个值。比如上图假设答案为98。

如何计算mid #

mid的计算其实也是存在陷阱的。如果选取的lohi足够大,那么(lo + hi) / 2的计算过程是有可能产生溢出的。为了解决这个问题,可以把lo + hi的结果进行无符号右移,这样即使溢出成为负数由于逻辑右移不会在符号位上进行补1,自然结果不会出错。

那么把

int mid = (lo + hi) / 2

修改为

int mid = (lo + hi) >>> 2

有趣的是Java实现的二分查找里这个bug一直存在了10年之久,存在于java.util.Arrays.binarySearch里面。感兴趣的可以看看这个前G家工程师写的文章

另外,int mid = (lo + hi) >>> 2是可以写为int mid = (lo + hi + 1) >>> 2,在上面的写法它们是等价的。但如果换到下面的开闭区间的写法,那么你就要小心了。

开区间和闭区间 #

如果我们再把上面的写法变一变:

int guessNumber() {
    int lo = 1, hi = 101;
    while (lo < hi) {
        int mid = (lo + hi) >>> 2;
        if (guess(mid) >= 0) hi = mid;
        else lo = mid + 1;
    }
    // return either lo or hi
    return hi;
}

这个时候我们发现退出条件变为lo < hi,运用上述的区间分析法可以画出另一幅图:

img

嗯,应该很快就能看出来,当循环结束时满足关系式lo == hi,所以如果两个区间应该一个为开区间,一个为闭区间,这样把覆盖所有的区间。左边区间内的数小于答案,而右边区间则大于等于答案,又由于右边区间是闭区间所以右边区间的左端点值就是我们要求的答案,听起来好像挺拗口。

可能你会问到,凭什么左边一定要是开区间,右边是闭区间,颠倒过来可不可以?

答案是不行的,否则这个程序就会成为死循环。问题出在上面所讲的mid中间值的计算方式上面。

假设程序是对的,当它退出的前一刻,lohi值是相邻分布的,如图所示:

img

感兴趣的读者可以自证一下,此时int mid = (lo + hi) / 2计算之后mid的取值永远落在lo上面。假设此时要移动lo来扩张左边区间,但很不幸地是,如果左边区间为闭区间,lo = mid只会使得整个过程再来一遍。所以左边区间必须作为开区间来完成最后的收敛,如果我们要左边是闭区间并且区间的右端点是答案的话那么代码就要改写为下面这样:

int guessNumber() {
    int lo = 0, hi = 100;
    while (lo < hi) {
        int mid = (lo + hi + 1) >>> 2;
        if (guess(mid) > 0) hi = mid - 1;
        else lo = mid;
    }
    // return either lo or hi
    return hi;
}

一般来讲,开开区间的写法一般用于精确值的查找,而开闭区间基本适用于一般情况。比如给出一个数n,让你找到在一个递增数组中是否存在,如果存在就返回索引,如果不存在则返回应该插入的索引。实际上这个问题就是让你以右边区间满足大于等于n,返回右边闭区间的左端点的索引。

int find(int n, int[] arr) {
    int lo = 0, hi = arr.length;
    while (lo < hi) {
        int mid = (lo + hi) >>> 1;
        if (arr[mid] < n) lo = mid + 1;
        else hi = mid;
    }
    return hi;
}

是不是很容易就能写出满足题意的代码?

闭区间和闭区间 #

第三种写法是这样的:

int guessNumber() {
    int lo = 0, hi = 101;
    while (lo + 1 < hi) {
        int mid = (lo + hi) / 2;
        if (guess(mid) > 0) hi = mid;
        else lo = mid;
    }
    return lo;
}

估计你的表情应该是这样的:

别说是你了,就是很多面试官也没见过这种写法,但它还真就是对的。运用区间分析法再画一张图:

img

可以看出来当循环终止时满足关系式lo + 1 == hi,那么为了覆盖整个区间左右两边都必须为闭区间,结合上面的分析你应该很快明白过来这代码为什么这么写。

不变量 #

但是区间分析法其实要求左右两个区间满足不同性质,这在某些场合是并不适用的。比如在《调试九法》中David曾经提到用二分排查法寻找电话线的故障,一捆电话线每次均分为两捆,如果其中一捆出现问题再到这捆中再进行这一过程。你能说被均分的两捆电话线除了有问题的那根之外有什么性质上的差异?

来看一道题,假设一个数组[1,2,5,3,8,6,7,4]中寻找任意峰值,峰值的定义是该数比左右相邻的两数都大,只要返回任意一个峰值的索引即可,数组中的数不会重复且保证总有一个峰值存在,两端视为插入哨兵负无穷。

这个题如果用二分查找来解的话,实在看不出左区间和右区间有什么性质上的差异呀,因为只要返回任意一个峰值,那么意味着两边的区间都可能包含不止一个的峰值。既然两边区间不管用,但是区间的交汇点保证能带来答案的话,那么搜索过程两区间中的未搜索区间满足始终有峰值的假设应该是可行的,那么这个假设就叫做不变量。

int findPeak(int[] arr) {
    int lo = 0, hi = arr.length - 1;
    while (lo < hi) {
        int mid = (lo + hi) >>> 1;
        if (arr[mid] < arr[mid + 1]) hi = mid;
        else lo = mid + 1;
    }
    return lo;
}

这里面假设不变量:峰值一直存在于[lo, hi]中:

  1. 初始状态时,这一假设是满足的;
  2. 在每次搜索开始时,假设上一次满足该假设,那么若求得的arr[mid] < arr[mid + 1],就说明mid有可能是峰值,将其纳入右边的闭区间;否则就不是峰值,排除出去,包含于左边的开区间;
  3. 等循环退出时,由于不变量一直满足,此时lo == hi,那么hi或者lo就是要求的峰值索引。

例题 #

旋转数组搜索 #

假设某个递增数组arr在某个索引进行旋转操作,例如:

[0,1,2,4,5,6,7]变为[4,5,6,7,0,1,2]

该数组保证不包含重复值,给定一个目标值若存在让你返回在数组中的索引,否则返回-1

这道题很难区分左右区间有什么不同属性,所以在这里设定一个不变量,整个数组被轴枢划分为两个递增区域,假设[lo, hi]这个闭区间内存在目标值,假如这个数组存在目标值,于是有:

  1. 显然初始状态满足这个不变量;
  2. 如果arr[mid] <= arr[hi]就说明midhi在同一个递增区域中,考虑一般情况若·lo位于左边的递增区域,midhi位于右边的递增区域,那么再把lo移到mid + 1或者hi移到mid的过程中,显然目标值在(mid, hi]时将lo移到mid + 1的情况更简单,此时只需保证target <= nums[hi] && target > nums[mid];若arr[mid] > arr[hi],则说明midhi不在同一个递增区域中,考虑一般情况,lomid位于左边的递增区域,那么如果target[lo, mid]时,将hi移到mid;
  3. 循环退出时,lo == hi标志着lo值索引即为所求答案。

写成代码为:

int searchInRotateArray(int[] arr, int target) {
    int lo = 0, hi = arr.length - 1;
    while (lo < hi) {
        int mid = (lo + hi) >>> 1;
        if (arr[mid] <= arr[hi]) {
            if (target > arr[mid] && target <= nums[hi]) lo = mid + 1;
            else hi = mid;
        } else {
            if (target <= arr[mid] && target > nums[hi]) hi = mid;
            else lo = mid +1;
        }
    }
    return nums.length == 0 ? -1 : nums[lo] == target ? lo : -1;
}

如果数组存在有重复值,此时当arr[mid] == arr[hi]时无法确定midhi是否在同一递增区域,需要单独拿出来讨论,假设target == arr[hi]那么这个值就找到了,否则在不变量成立的前提下将hi从区间剔除,维持不变量成立。

代码如下:

int searchInRotateArrayWithDuplicates(int[] arr, int target) {
    int lo = 0, hi = arr.length - 1;
    while (lo < hi) {
        int mid = (lo + hi) >>> 1;
        if (arr[mid] < arr[hi]) {
            if (target > arr[mid] && target <= nums[hi]) lo = mid + 1;
            else hi = mid;
        } else if (arr[mid] > arr[hi]) {
            if (target <= arr[mid] && target > nums[hi]) hi = mid;
            else lo = mid +1;
        } else {
        	if (target == arr[hi]) return hi;
        	hi--;
        }
    }
    return nums.length == 0 ? -1 : nums[lo] == target ? lo : -1;
}

带有黑名单的随机数挑选 #

假设有一个整数序列[0, N)和一个黑名单数组B,黑名单数组的数字不能被挑选。现在要求你撰写一个方法使得挑选出来的数字不在黑名单且满足随机分布。

估计题意很容易看懂也能哈希表实现出来这样的要求,但如果说让你用二分查找去实现这个要求,那可能有点难。

现在想象一下,假设黑名单数组按增序排列,且原数组长度为N,再将长度为M的黑名单数组挤出序列后,后面的数字依次补缺,那么对于B[i] < x < B[i + 1]而言,有关系式x = index + i + 1,其中indexx的索引,则index + d + 1 > B[i] + d - i >= B[d], i > d

故代码可写为:

class Solution {
    int[] b;
    int len;
    Random rand;

    public Solution(int N, int[] blacklist) {
        len = N - blacklist.length;
        b = blacklist;
        Arrays.sort(b);
        rand = new Random();
    }
    
    public int pick() {
        int index = rand.nextInt(len);
        int lo = -1, hi = b.length - 1;
        while (lo < hi) {
            int mid = (lo + hi + 1) >>> 1;
            int guess = index + mid + 1;
            if (guess > b[mid]) lo = mid;
            else hi = mid - 1;
        }
        return index + lo + 1;
    }
}