4.寻找两个正序数组的中位数

给定两个大小分别为 
m
 和 
n
 的正序(从小到大)数组 
nums1
 和 
nums2
。请你找出并返回这两个正序数组的 中位数 。算法的时间复杂度应该为 
O(log (m+n))
 。(暗示我们用二分法



输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2
 
输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5

一. 暴力解法

长度之和为奇数时,返回1个中位数;长度之和为偶数时,返回合并排序后中间两数的平均数。

如果不考虑复杂度的限制,可以想到先合并两个数组,再排序找到中位数。这个方法的弊端是浪费了“数组正序”的条件。时间复杂度取决于排序算法的时间复杂度,利用快速排序算法,它的时间复杂度是 4.寻找两个正序数组的中位数 。

合并两个有序数组,可以利用双指针比较元素大小,按序插入新的数组中。遍历完两个有序数组,可以得到一个更大的有序数组。归并排序的合并步骤,它的时间复杂度为 4.寻找两个正序数组的中位数 。

二. 二分查找

2.1 基于中位数的作用

在统计中,中位数被用来:将一个集合划分为两个长度相等的子集,其中一个子集中的元素总是大于另一个子集中的元素。中位数只跟分割线两边的元素有关,这道题可以转化成寻找两个正序数组中的分割线。

4.寻找两个正序数组的中位数

我们不必分别去确定分割线在两个数组中的位置,分割线两侧元素的数量是可以计算出来的。

4.寻找两个正序数组的中位数

4.寻找两个正序数组的中位数

需要额外确认:

第一个数组在分割线左边的最大值 小于 第二个数组在分割线右边的最小值。第二个数组在分割线左边的最大值 小于 第一个数组在分割线右边的最小值。

2.2 基于中位数的计算

根据中位数的定义,当 m+n 是奇数时,中位数是两个有序数组中的第 (m+n)/2 个元素,当 m+n 是偶数时,中位数是两个有序数组中的第 (m+n)/2 个元素和第 (m+n)/2+1 个元素的平均值。

因此,这道题可以转化成寻找两个有序数组中的第 k 小的数,其中 k 为 (m+n)/2 或 (m+n)/2+1。

寻找两个有序数组中的第 k 小的数,可以先比较 A[k/2−1] 和 B[k/2−1] 之后,可以排除 k/2 个不可能是第 k 小的数,查找范围缩小了一半。同时,我们将在排除后的新数组上继续进行二分查找,并且根据我们排除数的个数,减少 k 的值,这是因为我们排除的数都不大于第 k 小的数

有以下三种情况需要特殊处理:

如果 A[k/2−1] 或者 B[k/2−1] 越界,那么我们可以选取对应数组中的最后一个元素。在这种情况下,我们必须根据排除数的个数减少 k 的值,而不能直接将 k 减去 k/2。如果一个数组为空,说明该数组中的所有元素都被排除,我们可以直接返回另一个数组中第 k 小的元素。如果 k=1,我们只要返回两个数组首元素的最小值即可。

解法查询

基于中位数的计算 C



#define MIN(x, y) ((x) < (y) ? (x) : (y))
 
int getKthElement(int* nums1, int nums1Size, int* nums2, int nums2Size, int k){
    int index1=0, index2=0; //排除的数的坐标范围, < index 的部分
    while(true){
        if(index1 == nums1Size){
            return nums2[index2 + k -1];
        }
        if(index2 == nums2Size){
            return nums1[index1 + k -1];
        }
        if(k == 1){
            return MIN(nums1[index1], nums2[index2]);
        }
        
        int newIndex1 = MIN(index1 + k/2 -1, nums1Size-1);
        int newIndex2 = MIN(index2 + k/2 -1, nums2Size-1);
        if(nums1[newIndex1]<=nums2[newIndex2]){
            k -= newIndex1 - index1 + 1;
            index1 = newIndex1 +1;
        }else{
            k -= newIndex2 - index2 + 1;
            index2 = newIndex2 +1;
        }
    }
}
 
double findMedianSortedArrays(int* nums1, int nums1Size, int* nums2, int nums2Size) {
    int totalSize = nums1Size + nums2Size;
    if(totalSize%2 == 1){
        return getKthElement(nums1, nums1Size, nums2, nums2Size, (totalSize+1)/2);
    }else{
        return ( getKthElement(nums1, nums1Size, nums2, nums2Size, totalSize/2) 
        + getKthElement(nums1, nums1Size, nums2, nums2Size, totalSize/2 + 1) ) / 2.0;
    }
}

暴力解法 – 快速排序 C++

辅助数组法

这种快速排序实现方式使用一个辅助数组来存储划分后的结果,相比原地排序更直观易懂,但会使用额外的O(n)空间。



#include <iostream>
#include <vector>
 
using namespace std;
 
void quickSortWithAux(vector<int>& arr, vector<int>& aux, int low, int high) {
    if (low >= high) return;  // 递归终止条件
    
    int pivot = arr[high];    // 选择最后一个元素作为基准
    int left = low;           // 左指针 - 用于小于pivot的元素
    int right = high;         // 右指针 - 用于大于等于pivot的元素
    
    // 将划分结果存入辅助数组
    for (int i = low; i < high; i++) {
        if (arr[i] < pivot) {
            aux[left++] = arr[i];  // 小于基准的放左边
        } else {
            aux[right--] = arr[i];  // 大于等于基准的放右边
        }
    }
    
    aux[left] = pivot;  // 放置基准值
    
    // 将辅助数组的结果复制回原数组
    for (int i = low; i <= high; i++) {
        arr[i] = aux[i];
    }
    
    // 递归排序左右子数组
    quickSortWithAux(arr, aux, low, left - 1);
    quickSortWithAux(arr, aux, left + 1, high);
}
 
void quickSort(vector<int>& arr) {
    if (arr.size() <= 1) return;
    vector<int> aux(arr.size());  // 创建辅助数组
    quickSortWithAux(arr, aux, 0, arr.size() - 1);
}

 分区演示



//low=0, high=4
原数组: [3, 1, 4, 2, 5] (pivot=5)
分区过程:
aux数组变化:
[3, 0, 0, 0, 0]   // 3 < 5 → 放左边
[3, 1, 0, 0, 0]   // 1 < 5 → 放左边
[3, 1, 0, 4, 0]   // 4 < 5 → 放左边
[3, 1, 2, 4, 0]   // 2 < 5 → 放左边
最终:
[3, 1, 2, 4, 5]   // 放入基准值
//left=4, right=4
 
// 排序左半部分(小于基准的部分)
quickSortWithAux(arr, aux, 0, 3);
 
// 排序右半部分(大于基准的部分) 
quickSortWithAux(arr, aux, 5, 4);

 原地操作法



#include <iostream>
#include <vector>
 
using namespace std;
 
int partition(vector<int>& arr, int low, int high) {
    int pivot = arr[high];    // 选择最后一个元素作为基准
    int i = low - 1;          // 指向小于基准的区域的末尾
    
    for (int j = low; j < high; j++) {
        if (arr[j] < pivot) {
            i++;
            swap(arr[i], arr[j]);  // 将小于基准的元素交换到前面
        }
    }
    
    swap(arr[i + 1], arr[high]);  // 将基准放到正确位置
    return i + 1;                 // 返回基准位置
}
 
void quickSort(vector<int>& arr, int low, int high) {
    if (low < high) {
        // 划分数组并获取基准位置
        int pivotIndex = partition(arr, low, high);
        
        // 递归排序左右子数组
        quickSort(arr, low, pivotIndex - 1);
        quickSort(arr, pivotIndex + 1, high);
    }
}
 
void quickSort(vector<int>& arr) {
    if (arr.size() <= 1) return;
    quickSort(arr, 0, arr.size() - 1);
}

暴力解法 – 归并排序 C++

借鉴归并排序的合并步骤:

使用双指针分别遍历两个数组比较指针所指元素,将较小的放入合并数组当一个数组遍历完后,将另一个数组剩余元素直接加入



#include <vector>
#include <iostream>
 
using namespace std;
 
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
    int m = nums1.size(), n = nums2.size();
    vector<int> merged;
    int i = 0, j = 0;
    
    // 归并两个有序数组
    while (i < m && j < n) {
        if (nums1[i] <= nums2[j]) {
            merged.push_back(nums1[i++]);
        } else {
            merged.push_back(nums2[j++]);
        }
    }
    
    // 处理剩余元素
    while (i < m) merged.push_back(nums1[i++]);
    while (j < n) merged.push_back(nums2[j++]);
    
    // 计算中位数
    int total = merged.size();
    if (total % 2 == 1) {
        return merged[total / 2];
    } else {
        return (merged[total / 2 - 1] + merged[total / 2]) / 2.0;
    }
}

基于中位数的作用 C++



class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        if (nums1.size() > nums2.size()) {
            return findMedianSortedArrays(nums2, nums1);
        }
 
        int m = nums1.size();
        int n = nums2.size();
 
        int left = 0, right = m;
 
        while(left < right){
            int i = left + (right - left + 1) / 2;
            int j = (m + n + 1) / 2 - i;
 
            if (nums1[i-1] > nums2[j]) {
                right = i - 1;
            } else {
                left = i;
            }
        }
 
        int i = left;
        int j = (m + n + 1) / 2 - i;
        
        int nums1LeftMax = (i == 0 ? INT_MIN : nums1[i - 1]);
        int nums1RightMin = (i == m ? INT_MAX : nums1[i]);
        int nums2LeftMax = (j == 0 ? INT_MIN : nums2[j - 1]);
        int nums2RightMin = (j == n ? INT_MAX : nums2[j]);
 
        int medianLeft = max(nums1LeftMax, nums2LeftMax);
        int medianRight = min(nums1RightMin, nums2RightMin);
 
        return (m + n) % 2 == 0 ? (medianLeft + medianRight) / 2.0 : medianLeft;
    }
};

基于中位数的作用 Java



class Solution {
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        if (nums1.length > nums2.length) {
            return findMedianSortedArrays(nums2, nums1);
        }
 
        int m = nums1.length;
        int n = nums2.length;
 
        int left = 0, right = m;
 
        while (left < right) {
            int i = left + (right - left + 1) / 2;
            int j = (m + n + 1) / 2 - i;
 
            if (nums1[i - 1] > nums2[j]) {
                right = i - 1;
            } else {
                left = i;
            }
        }
 
        int i = left;
        int j = (m + n + 1) / 2 - i;
        
        int nums1LeftMax = (i == 0 ? Integer.MIN_VALUE : nums1[i - 1]);
        int nums1RightMin = (i == m ? Integer.MAX_VALUE : nums1[i]);
        int nums2LeftMax = (j == 0 ? Integer.MIN_VALUE : nums2[j - 1]);
        int nums2RightMin = (j == n ? Integer.MAX_VALUE : nums2[j]);
 
        int medianLeft = Math.max(nums1LeftMax, nums2LeftMax);
        int medianRight = Math.min(nums1RightMin, nums2RightMin);
 
        return (m + n) % 2 == 0 ? (medianLeft + medianRight) / 2.0 : medianLeft;
    }
}

基于中位数的作用 Python3



class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        if len(nums1) > len(nums2):
            return self.findMedianSortedArrays(nums2, nums1)
 
        m, n = len(nums1), len(nums2)
        
        left, right = 0, m
        
        while left < right:
            i = left + (right - left + 1) // 2
            j = (m + n + 1) // 2 - i
            
            if nums1[i - 1] > nums2[j]:
                right = i - 1
            else:
                left = i
        
        i = left
        j = (m + n + 1) // 2 - i
        
        nums1_left_max = -sys.maxsize - 1 if i == 0 else nums1[i - 1]
        nums1_right_min = sys.maxsize if i == m else nums1[i]
        nums2_left_max = -sys.maxsize - 1 if j == 0 else nums2[j - 1]
        nums2_right_min = sys.maxsize if j == n else nums2[j]
        
        median_left = max(nums1_left_max, nums2_left_max)
        median_right = min(nums1_right_min, nums2_right_min)
        
        if (m + n) % 2 == 0:
            return (median_left + median_right) / 2.0
        else:
            return median_left

© 版权声明

相关文章

暂无评论

none
暂无评论...