Convenient and near-optimal binary search on floating point numbers
This post was originally written on Codeforces; relevant discussion can be found here. TL;DR Use the following template (C++20) for efficient and near-optimal binary search (in terms of number of queries) on floating point numbers. Template template <std::size_t N_BITS> using int_least_t = std::conditional_t< N_BITS <= 8, std::uint8_t, std::conditional_t< N_BITS <= 16, std::uint16_t, std::conditional_t< N_BITS <= 32, std::uint32_t, std::conditional_t< N_BITS <= 64, std::uint64_t, std::conditional_t<N_BITS <= 128, __uint128_t, void>>>>>; // this should work for float and doubles, but for long doubles, std::bit_cast will fail on most systems due to being 80 bits wide. // to handle this, consider using doubles instead or std::bit_cast the long double to an 80-bit bitset and convert it to a 128 bit integer using to_ullong. /* * returns first x in [a, b] such that predicate(x) is false, conditioned on * logical_predicate(a) && !logical_predicate(b) && logical_predicate(-inf) && * !logical_predicate(inf) * here logical_predicate is the mathematical value of the predicate, not the * machine value of the predicate * it is guaranteed that non-nan, non-inf inputs are passed into the predicate * if NaNs or infinities are passed to this function as argument, then the * inputs to the predicate will start from smallest/largest representable * floating point numbers of the input type - this can be a source of errors * if you multiply the input by something > 1 for example * strictly speaking, the predicate should also be perfectly monotonic, but if * it gives out-of-order booleans in some small range [a, a + eps] (and the * correct order elsewhere), then the answer will be somewhere in between * the same holds for how denormals are handled by this code */ // template <bool check_infinities = false, bool distinguish_plus_minus_zero = false, bool deal_with_nans_and_infs = false, std::floating_point T> T partition_point_fp(T a, T b, auto&& predicate) { static constexpr std::size_t T_WIDTH = sizeof(T) * CHAR_BIT; using Int = int_least_t<T_WIDTH>; static constexpr auto is_negative = [](T x) { return static_cast<bool>((std::bit_cast<Int>(x) >> (T_WIDTH - 1)) & 1); }; if constexpr (distinguish_plus_minus_zero) { if (a == T(0.0) && b == T(0.0) && is_negative(a) && !is_negative(b)) { if (!predicate(-T(0.0))) { return -T(0.0); } else { // predicate(0.0) is guaranteed to be true because b = 0.0 return T(0.0); } } } if (a >= b) return NAN; if constexpr (deal_with_nans_and_infs) { // get rid of NaNs as soon as possible if (std::isnan(a)) a = -std::numeric_limits<T>::infinity(); if (std::isnan(b)) b = std::numeric_limits<T>::infinity(); // deal with infinities if (a == -std::numeric_limits<T>::infinity()) { if constexpr (check_infinities) { if (predicate(-std::numeric_limits<T>::max())) { a = -std::numeric_limits<T>::max(); } else { return -std::numeric_limits<T>::max(); } } else { a = -std::numeric_limits<T>::max(); } } if (b == std::numeric_limits<T>::infinity()) { if constexpr (check_infinities) { if (!predicate(std::numeric_limits<T>::max())) { b = std::numeric_limits<T>::max(); } else { return std::numeric_limits<T>::infinity(); } } else { b = std::numeric_limits<T>::max(); } } } // now a and b are both finite - deal with differently signed a and b if (is_negative(a) && !is_negative(b)) { // check 0 once if constexpr (distinguish_plus_minus_zero) { if (!predicate(-T(0.0))) { b = -T(0.0); } else if (predicate(T(0.0))) { a = T(0.0); } else { return T(0.0); } } else { if (!predicate(T(0.0))) { b = -T(0.0); } else { a = T(0.0); } } } // in the case a and b are both 0 after the above check, return 0 if (a == b) return T(0.0); // start actual binary search auto get_int = [](T x) { return std::bit_cast<Int, T>(x); }; auto get_float = [](Int x) { return std::bit_cast<T, Int>(x); }; if (b > 0) { while (get_int(a) + 1 < get_int(b)) { auto m = std::midpoint(get_int(a), get_int(b)); if (predicate(get_float(m))) { a = get_float(m); } else { b = get_float(m); } } } else { while (get_int(-b) + 1 < get_int(-a)) { auto m = std::midpoint(get_int(-b), get_int(-a)); if (predicate(-get_float(m))) { a = -get_float(m); } else { b = -get_float(m); } } } return b; } It is also possible to extend this to breaking early when a custom closeness predicate is true (for example, min(absolute error, relative error) < 1e-9 and so on), but for the sake of simplicity, this template does not do so. ...