diff --git a/budd-common/src/main/java/io/github/ehlxr/algorithm/search/BinarySearch.java b/budd-common/src/main/java/io/github/ehlxr/algorithm/search/BinarySearch.java index 52ca52e..ff9ae2f 100644 --- a/budd-common/src/main/java/io/github/ehlxr/algorithm/search/BinarySearch.java +++ b/budd-common/src/main/java/io/github/ehlxr/algorithm/search/BinarySearch.java @@ -35,6 +35,12 @@ public class BinarySearch { System.out.println(searchRec(a, a.length, 3)); a = new int[]{1, 2, 2, 2, 2, 5, 6, 7, 8, 9}; System.out.println(searchFirst(a, a.length, 2)); + a = new int[]{1, 3, 4, 5, 6, 8, 8, 8, 9}; + System.out.println(searchLast(a, a.length, 8)); + a = new int[]{1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9}; + System.out.println(searchFirstGreater(a, a.length, 7)); + a = new int[]{1, 2, 3, 4, 5, 6, 8, 9}; + System.out.println(searchLastLess(a, a.length, 7)); } /** @@ -49,12 +55,10 @@ public class BinarySearch { int low = 0; int high = n - 1; while (low <= high) { - int mid = low + (high - low) >> 1; + int mid = low + ((high - low) >> 1); if (a[mid] == v) { return mid; - } - - if (a[mid] < v) { + } else if (a[mid] < v) { low = mid + 1; } else { high = mid - 1; @@ -83,12 +87,10 @@ public class BinarySearch { return -1; } - int mid = low + (high - low) >> 1; + int mid = low + ((high - low) >> 1); if (a[mid] == v) { return mid; - } - - if (a[mid] < v) { + } else if (a[mid] < v) { return searchRec(a, mid + 1, high, v); } else { return searchRec(a, low, mid - 1, v); @@ -107,16 +109,14 @@ public class BinarySearch { int low = 0; int high = n - 1; while (low <= high) { - int mid = low + (high - low) >> 1; + int mid = low + ((high - low) >> 1); if (a[mid] == v) { - if (a[mid - 1] != v) { + if (mid == 0 || a[mid - 1] != v) { return mid; } else { high = mid - 1; } - } - - if (a[mid] < v) { + } else if (a[mid] < v) { low = mid + 1; } else { high = mid - 1; @@ -124,4 +124,84 @@ public class BinarySearch { } return -1; } + + /** + * 查找最后一个值等于给定值的元素 + * + * @param a 要查找的数组 + * @param n 数组的长度 + * @param v 要查找的值 + * @return 要查找值在数组中的索引 + */ + public static int searchLast(int[] a, int n, int v) { + int low = 0; + int high = n - 1; + while (low <= high) { + int mid = low + ((high - low) >> 1); + if (a[mid] == v) { + if (mid == n - 1 || a[mid + 1] != v) { + return mid; + } else { + low = mid + 1; + } + } else if (a[mid] < v) { + low = mid + 1; + } else { + high = mid - 1; + } + } + return -1; + } + + /** + * 查找第一个大于等于给定值的元素 + * + * @param a 要查找的数组 + * @param n 数组的长度 + * @param v 要查找的值 + * @return 要查找值在数组中的索引 + */ + public static int searchFirstGreater(int[] a, int n, int v) { + int low = 0; + int high = n - 1; + while (low <= high) { + int mid = low + ((high - low) >> 1); + if (a[mid] < v) { + low = mid + 1; + } else { + if (mid == 0 || a[mid - 1] < v) { + return mid; + } else { + high = mid - 1; + } + } + } + return -1; + } + + /** + * 查找最后一个小于等于给定值的元素 + * + * @param a 要查找的数组 + * @param n 数组的长度 + * @param v 要查找的值 + * @return 要查找值在数组中的索引 + */ + public static int searchLastLess(int[] a, int n, int v) { + int low = 0; + int high = n - 1; + while (low <= high) { + int mid = low + ((high - low) >> 1); + if (a[mid] > v) { + high = mid - 1; + } else { + if (mid == n - 1 || a[mid + 1] > v) { + return mid; + } else { + low = mid + 1; + } + } + } + return -1; + } }