from dataclasses import dataclass


class Probe:
    """One evaluation of a reduction tree (``rtree``) along a path of choices.

    The rtree is called with a ``check(reason)`` callback. The i-th call is
    answered with ``path[i]`` while the path lasts and with False after it.
    Every reason passed to ``check`` is recorded in ``reasons`` in call order.

    Attributes:
        path: the given path followed by ``guess_depth`` guessed trues
            (built by list concatenation, so the caller's list is not
            modified).
        depth: the number of guessed trues appended to the path.
        reasons: the reasons passed to ``check`` during the run, in order.
        input: whatever ``rtree(check)`` returned.
    """

    def __init__(self, rtree, path, guess_depth=0):
        """Run ``rtree`` with ``path`` extended by ``guess_depth`` trues."""
        guessed = [True] * guess_depth
        self.path = path + guessed
        self.depth = guess_depth
        self.reasons = []

        def check(reason: str):
            """Record ``reason``; answer from the path, or False past its end."""
            self.reasons.append(reason)
            if self.undecided() > 0:
                # The rtree asked more questions than the path covers.
                return False
            return self.path[len(self.reasons) - 1]

        self.input = rtree(check)

    def undecided(self):
        """Return how many checks were asked beyond the end of the path.

        A negative value means the path has that many trailing entries that
        the rtree never consulted.
        """
        asked = len(self.reasons)
        planned = len(self.path)
        return asked - planned


def reduce(predicate, rtree):
    """Reduce the input built by ``rtree`` while keeping ``predicate`` true.

    ``rtree`` is run through ``Probe``. It receives a ``check(reason)``
    callback, and answering True to a check selects that reduction.
    ``predicate`` is called with each candidate ``Probe``; its truthiness
    decides whether the candidate is accepted.

    The choices are decided left to right. Starting from the accepted path,
    the search appends ``depth`` guessed trues and doubles ``depth`` after
    every accepted probe (exponential search). When a probe is rejected, a
    binary search accepts as many of the guessed trues as possible. After
    that, the next choice is fixed to False. This repeats until the rtree
    asks no checks beyond the accepted path.

    Precondition (assumed, not checked here): the input produced when every
    check is answered False satisfies ``predicate``. The invariants noted
    below also presume that ``rtree`` and ``predicate`` are deterministic.

    Returns the last accepted ``Probe``; its ``input`` is the reduced input.
    """
    # Initial probe with an empty path: every check is answered False.
    rp = Probe(rtree, [])
    # Run an exponential search for the deepest sequence of trues
    # that can be appended to the path without failing the predicate.
    # ``depth`` is the number of trues guessed by the next probe.
    depth = 1
    # Invariant: predicate(rp) == True (initially by the precondition).
    while rp.undecided() > 0:
        # Try to probe the rtree with the current path extended by `depth` trues.
        if predicate(rp_ := Probe(rtree, rp.path, depth)):
            rp, depth = rp_, depth * 2
            continue
        # Drop guessed trues the rejected probe never consulted (a negative
        # undecided()), so none of them are wasted. They did not affect its
        # input, so the shorter guess is rejected as well.
        depth += min(0, rp_.undecided())
        # Perform a binary search, accepting the necessary trues, and
        # reducing the unknown trues to 1. Invariant: extending rp.path by
        # `depth` trues is rejected by the predicate.
        while depth > 1:
            # Probe the rtree with the current path extended by half the depth.
            if predicate(rp_ := Probe(rtree, rp.path, depth // 2)):
                rp = rp_  # Accept the current path.
                depth -= depth // 2  # And try the remaining half
            else:
                depth //= 2  # Halve the number of guessed trues
        # Extending the path by a single true is rejected, so the next
        # element in the path has to be false. rp.input is not recomputed:
        # checks past the end of the path were already answered False.
        rp.path.append(False)
    # Return the final accepted probe.
    return rp


def debug(predicate):
    """Wrap ``predicate`` so that each probe it judges is traced to stdout.

    The returned function takes a probe ``rp`` and evaluates
    ``predicate(rp.input)``. It then prints two lines:
      * a running call number, followed by the reasons asked for the guessed
        trues (the last ``rp.depth`` entries of ``rp.path``);
      * the probed input and the predicate's verdict.
    The verdict is returned unchanged.
    """
    calls = 0

    def newpred(rp):
        nonlocal calls
        calls += 1
        verdict = predicate(rp.input)
        end = len(rp.path)
        guessed = rp.reasons[end - rp.depth : end]
        print(f"{calls:02})", ", ".join(guessed))
        print(f"... P({rp.input!r}) = {verdict}")
        return verdict

    return newpred


def latex(
    predicate,
    query_format="Remove {}?".format,
    input_format="\\verb|{}|".format,
    start_count=0,
):
    """Wrap ``predicate`` so that each evaluation prints one LaTeX table row.

    The returned function takes a probe ``rp``, evaluates
    ``predicate(rp.input)`` and prints a row with these cells, separated by
    ``&`` and terminated by ``\\\\``:
      1. a row number, starting at ``start_count``;
      2. the probe's path rendered by ``pretty`` inside ``\\verb``;
      3. ``query_format`` applied to the reasons asked for the guessed trues
         (the last ``rp.depth`` entries of ``rp.path``), joined by ", ";
      4. ``input_format`` applied to the probed input;
      5. ``true`` or ``false`` according to the verdict.
    The verdict is returned unchanged.
    """
    row_number = start_count - 1

    def newpred(rp):
        nonlocal row_number
        row_number += 1
        verdict = predicate(rp.input)
        end = len(rp.path)
        query = ", ".join(rp.reasons[end - rp.depth : end])
        outcome = "true" if verdict else "false"
        cells = [
            f"{row_number:02}",
            f"\\verb|{pretty(rp)}|",
            f"{query_format(query)}",
            f"{input_format(rp.input)}",
            f"{outcome} \\\\",
        ]
        print(" & ".join(cells))
        return verdict

    return newpred


def pretty(rp):
    """Render a probe's path as a compact bit string.

    Each path entry becomes "1" (truthy) or "0" (falsy). If the probe
    recorded more reasons than the path has entries, one "*" is appended for
    each extra reason, marking choices that were left to the default.
    """
    bits = ["1" if choice else "0" for choice in rp.path]
    missing = max(0, len(rp.reasons) - len(bits))
    return "".join(bits) + "*" * missing


def reduce_abc(check) -> str:
    """Example reduction tree over the lowercase alphabet "a" to "z".

    Calls ``check`` once per letter, in alphabetical order, with the letter
    as the reason. A truthy answer removes that letter. Removed letters are
    replaced by a space so the remaining letters keep their positions.

    Returns the resulting 26-character string.
    """
    # All 26 letters; the previous literal accidentally skipped "w".
    return "".join(
        " " if check(letter) else letter for letter in "abcdefghijklmnopqrstuvwxyz"
    )


if __name__ == "__main__":
    # Demo: reduce the alphabet while it still contains "a" or "p",
    # printing every predicate evaluation as a LaTeX table row.
    fmt = "\\verb|{}|".format
    predicate = latex(lambda e: "a" in e or "p" in e, start_count=0)
    # Row 00: the unreduced input (every check answered False).
    predicate(Probe(reduce_abc, []))
    result = reduce(predicate, reduce_abc)
    # Final row: the reduced input.
    print(f"&& {fmt(result.input)} & true \\\\")