from dataclasses import dataclass
import math
from typing import TypeVar


def reduce_simp(predicate, rtree):
    """Greedily reduce ``rtree`` one choice at a time (simple reference version).

    ``rtree`` is called with a ``check(label)`` callback and answers each
    question from ``path``; questions past the end of ``path`` get ``False``.
    Every choice the accepted run asks is tentatively set to ``True``. It is
    kept if ``predicate`` accepts the resulting input; otherwise it is fixed
    to ``False``. One trace line is printed per tested choice.

    Returns the last input accepted by ``predicate``. The initial all-False
    input is assumed (not checked) to satisfy it.
    """
    path: list[bool] = []
    labels: list[str] = []

    def check(label):
        labels.append(label)
        return path[len(labels) - 1] if len(path) >= len(labels) else False

    i = rtree(check)
    # Number of questions asked by the run that produced `i`. This must track
    # the *accepted* run: after a rejected guess, `labels` describes the
    # rejected run, which may have asked more or fewer questions.
    asked = len(labels)
    while len(path) < asked:
        path.append(True)
        labels.clear()
        j = rtree(check)
        t = predicate(j)
        if t:
            i, asked = j, len(labels)
        else:
            # False is what the accepted run already defaulted to, so `i`
            # and `asked` stay valid.
            path[-1] = False
        print(f"{labels[len(path) - 1]} ... test {j!r:<10} ... {t}")
    return i


class Probe:
    """Record of a single rtree run driven by a fixed list of decisions.

    Attributes:
        path: decisions fed to the rtree, in the order it asks; questions
            past the end of ``path`` are answered ``False``.
        depth: how many guessed ``True`` entries were appended to ``path``.
        reasons: every label the rtree passed to its callback, in order.
        input: whatever the rtree returned for this run.
    """

    def __init__(self, rtree, path, guess_depth=0):
        """Run ``rtree`` with ``path`` followed by ``guess_depth`` guessed ``True``s.

        The caller's ``path`` list is not modified; a new list is built.
        """
        self.path = path + guess_depth * [True]
        self.depth = guess_depth
        self.reasons = []

        def answer(reason: str):
            """Log ``reason`` and return the matching path entry, else ``False``."""
            self.reasons.append(reason)
            position = len(self.reasons) - 1
            if position >= len(self.path):
                return False
            return self.path[position]

        self.input = rtree(answer)

    def undecided(self):
        """Count of questions asked beyond the end of ``path``.

        Positive: the rtree made that many choices that defaulted to ``False``.
        Zero or negative: every question was answered from ``path``. A negative
        value means ``path`` has entries the rtree never consulted.
        """
        return len(self.reasons) - len(self.path)


def reduce1(predicate, rtree):
    """Reduce ``rtree`` by deciding one choice at a time (linear search).

    For each undecided choice, probe with a single ``True`` appended to the
    accepted path. Keep it if ``predicate`` accepts the probe; otherwise fix
    that choice to ``False``.

    ``predicate`` receives a Probe. The initial all-``False`` probe is assumed
    (not checked) to be accepted. Returns the final accepted Probe; its
    ``input`` is the reduced result.
    """
    accepted = Probe(rtree, [])
    # Invariant: `accepted` is the latest probe the predicate agreed with.
    while accepted.undecided() > 0:
        candidate = Probe(rtree, accepted.path, 1)
        if not predicate(candidate):
            # Fixing the choice to False matches what `accepted` already
            # defaulted to, so its input stays valid.
            accepted.path.append(False)
        else:
            accepted = candidate
    return accepted


def reduce(predicate, rtree):
    """Reduce ``rtree`` with an exponential + binary search over runs of Trues.

    ``predicate`` is called with a Probe (not the raw input) and should return
    truthy when the probe's input is acceptable. The initial probe, where every
    choice defaults to False, is assumed (not checked) to be accepted.

    Returns the final accepted Probe; its ``input`` is the reduced result.
    """
    # Extract the initial reduction probe: every choice defaults to False.
    rp = Probe(rtree, [])
    # Exponential search for the longest run of Trues that can be appended
    # to the path without failing the predicate.
    depth = 1
    # Invariant: predicate(rp) is truthy (assumed for the initial probe).
    while rp.undecided() > 0:
        # Try to probe with the current path extended by `depth` Trues.
        if predicate(rp_ := Probe(rtree, rp.path, depth)):
            rp, depth = rp_, depth * 2
            continue
        # Drop guessed Trues the failing run never consumed (undecided() < 0
        # means the path outran the questions asked), so none are wasted.
        depth += min(0, rp_.undecided())
        # Binary search: accept the Trues that are possible, shrinking the
        # window of unknown Trues down to the single one that must be False.
        while depth > 1:
            # Probe the rtree with the current path extended by half the window.
            if predicate(rp_ := Probe(rtree, rp.path, depth // 2)):
                rp = rp_  # Accept the extended path.
                depth -= depth // 2  # And search the remaining half.
            else:
                depth //= 2  # The failing True lies in the first half.
        # Record that the next element of the path has to be False; this is
        # the value rp's run already defaulted to, so rp.input stays valid.
        rp.path.append(False)
    # Return the final accepted probe (its .input is the reduced input).
    return rp


def debug(predicate):
    """Wrap ``predicate`` so that every probe it judges is logged.

    The returned callable takes a Probe-like object (with ``input``,
    ``reasons``, ``path`` and ``depth``) and evaluates ``predicate`` on its
    input. It then prints a numbered line listing the guessed choices,
    followed by the outcome, and returns the predicate's result.
    """
    calls = 0

    def logged(probe):
        nonlocal calls
        calls += 1
        outcome = predicate(probe.input)
        end = len(probe.path)
        guessed = probe.reasons[end - probe.depth : end]
        print(f"{calls:02})", ", ".join(guessed))
        print(f"... P({probe.input!r}) = {outcome}")
        return outcome

    return logged


def latex(
    predicate,
    query_format="Remove {}?".format,
    input_format="\\verb|{}|".format,
    start_count=0,
):
    """Wrap ``predicate`` so each evaluated probe prints one LaTeX table row.

    Row columns: row number (from ``start_count``), the probe's decision
    string from ``pretty``, the formatted guessed choices, the formatted
    input, and ``true``/``false``. The wrapper takes a Probe-like object and
    returns the predicate's result for its ``input``.
    """
    row = start_count - 1

    def emit_row(probe):
        nonlocal row
        row += 1
        outcome = predicate(probe.input)
        end = len(probe.path)
        query = ", ".join(probe.reasons[end - probe.depth : end])
        verdict = "true" if outcome else "false"
        # Formatters run in the same order as the columns they fill.
        trace = pretty(probe)
        shown_query = query_format(query)
        shown_input = input_format(probe.input)
        print(
            f"{row:02} & \\verb|{trace}| & {shown_query} & {shown_input} & {verdict} \\\\"
        )
        return outcome

    return emit_row


def table(
    predicate,
    input_format="{}".format,
    query_format=str,
    start_count=0,
):
    """Wrap ``predicate`` so each evaluated probe prints one plain-text row.

    Rows look like ``NN - <input> - <verdict> - <guessed choices>``. The
    verdict is padded to five characters so the columns line up, and
    numbering starts at ``start_count``. The wrapper takes a Probe-like
    object and returns the predicate's result for its ``input``.
    """
    row = start_count - 1

    def emit_row(probe):
        nonlocal row
        row += 1
        outcome = predicate(probe.input)
        end = len(probe.path)
        guessed = probe.reasons[end - probe.depth : end]
        query = ", ".join(map(query_format, guessed))
        verdict = "true " if outcome else "false"
        print(f"{row:02} - {input_format(probe.input)} - {verdict} - {query}")
        return outcome

    return emit_row


def pretty(rp):
    """Render a probe's decisions as a compact string, one character per position.

    - ``"1"``/``"0"``: the path decided True/False and the rtree asked that question.
    - ``"!"``: the path has a decision the rtree never asked about.
    - ``"*"``: the rtree asked a question beyond the path (it defaulted to False).
    """
    from itertools import zip_longest

    # A private sentinel: a string fill value such as "*" would collide with a
    # reason whose text happens to be that same string.
    missing = object()
    chars = []
    for decision, reason in zip_longest(rp.path, rp.reasons, fillvalue=missing):
        if reason is missing:
            chars.append("!")
        elif decision is missing:
            chars.append("*")
        else:
            chars.append("1" if decision else "0")
    return "".join(chars)


def reduce_abc(check) -> str:
    """Toy rtree: for each of ``a``, ``b`` and ``c``, ask whether to remove it.

    A removed letter is replaced by a space so positions stay aligned.
    """
    return "".join(
        " " if check(f"remove {letter}?") else letter for letter in "abc"
    )


def reduce_dd2(c: list, check) -> list:
    """Variant rtree: first ask whether ``c`` can be dropped entirely.

    If it cannot, split ``c`` in half, reduce the second half before the
    first (question order is part of the tree), and concatenate the results.
    Each half is reduced with this same strategy, so every sub-list also gets
    an "ignore" question.
    """
    if check(f"ignore {c}?"):
        return []
    # `<= 1` so an empty list terminates instead of splitting forever.
    if len(c) <= 1:
        return c
    pivot = len(c) // 2
    c2 = reduce_dd2(c[pivot:], check)
    c1 = reduce_dd2(c[:pivot], check)
    return c1 + c2


def reduce_dd(c: list, check) -> list:
    """Delta-debugging-style rtree: recursively narrow ``c`` by halves.

    At each level, ask whether the result lies entirely in the first half,
    then whether it lies entirely in the second half. If neither, reduce both
    halves independently and concatenate. Lists of length 0 or 1 are
    returned unchanged.
    """
    # `<= 1` rather than `== 1`: an empty list would otherwise split into two
    # empty halves forever and overflow the recursion limit.
    if len(c) <= 1:
        return c
    pivot = len(c) // 2
    c1 = c[:pivot]
    c2 = c[pivot:]
    if check(f"result in c1: {c1}?"):
        return reduce_dd(c1, check)
    if check(f"result in c2: {c2}?"):
        return reduce_dd(c2, check)
    return reduce_dd(c1, check) + reduce_dd(c2, check)


# Generic type of the value threaded through reduce_df.
I = TypeVar("I")


def reduce_df(i: I, actions, check) -> I:
    """Rtree that offers each action in ``actions`` as a choice.

    For every action, in order, asks ``check("apply <action>?")``. When the
    answer is truthy, the action is applied to the current value ``i``.
    Returns the (possibly transformed) value.

    NOTE(review): the ``for`` body contains no ``break``, so its ``else: break``
    always runs and the ``while`` loop makes at most one pass over
    ``actions``. It makes no pass at all if ``i`` is ``None``, because ``j``
    starts as ``None``. If repeating until a fixpoint was intended, this
    needs fixing — TODO confirm the intended semantics.
    """
    j = None  # previous value of i; read only by the `while` condition
    while i != j:
        for act in actions:
            if check(f"apply {act}?"):
                i, j = act(i), i
        else:
            # Always reached (the for-body never breaks): exit after one pass.
            break
    return i


def reduce_ddmin(input: list, check, n=2) -> list:
    """ddmin-style rtree: repeatedly offer chunks and complements of ``input``.

    ``input`` is split into chunks of ``ceil(len(input) / n)`` elements. Each
    chunk ``i`` is offered as ``"delta: i/n"``; accepting keeps only that
    chunk and resets ``n`` to 2. When there are more than two chunks, each
    complement is then offered as ``"complement: i/n"``; accepting drops
    chunk ``i`` and sets ``n`` to ``max(n - 1, 2)``. If nothing is accepted,
    ``n`` doubles. The search stops once at most one element remains or
    ``n`` reaches ``2 * len(input)``.

    Empty chunks (produced when ceil rounding leaves fewer than ``n``
    non-empty chunks) are never offered. Their delta is the empty list and
    their complement is ``input`` itself, so asking about them only wastes
    predicate calls and lets a reducer accept no-op steps.
    """

    def find_next(items, n):
        step = math.ceil(len(items) / n)
        # Chunks are laid out left to right, so the non-empty ones form a
        # prefix and `enumerate` still yields each chunk's original index.
        chunks = [
            (step * i, step * (i + 1)) for i in range(n) if step * i < len(items)
        ]
        for i, (lo, hi) in enumerate(chunks):
            if check(f"delta: {i}/{n}"):
                return items[lo:hi], 2

        # With at most two chunks, each complement equals the other chunk,
        # which was already offered as a delta (this covers n == 2).
        if len(chunks) > 2:
            for i, (lo, hi) in enumerate(chunks):
                if check(f"complement: {i}/{n}"):
                    return items[:lo] + items[hi:], max(n - 1, 2)

        return items, n * 2

    while len(input) > 1 and n < 2 * len(input):
        input, n = find_next(input, n)

    return input


def test_reduce_dd():
    """Reduce [1..8] with the reduce_dd tree, requiring 3 and 6 to remain."""
    from functools import partial

    p = debug(lambda a: 3 in a and 6 in a)
    rp = reduce(p, partial(reduce_dd, [1, 2, 3, 4, 5, 6, 7, 8]))
    # Print the reduced input rather than the Probe object (whose default
    # repr shows only an address), and check the reduction actually worked.
    print(rp.input)
    assert rp.input == [3, 6]


def test_reduce_ddmin():
    """Reduce [1..8] with the ddmin tree, requiring 1, 7 and 8 to remain."""
    from functools import partial

    def input_format(a):
        """Show which of 1..8 are present, e.g. ``X.....XX``."""
        return "".join("X" if i in a else "." for i in range(1, 9))

    p = table(
        lambda a: 1 in a and 7 in a and 8 in a,
        input_format=input_format,
        query_format=str,
    )
    rp = reduce1(p, partial(reduce_ddmin, [1, 2, 3, 4, 5, 6, 7, 8]))
    print(f"   - {input_format(rp.input)}")
    # The reducer only ever accepts inputs the predicate holds for.
    assert {1, 7, 8} <= set(rp.input)


if __name__ == "__main__":
    # Demo: print LaTeX table rows for reducing "abc" while keeping an "a"
    # (or a "p"), starting with the unreduced input as row 00.
    keeps_a = latex(
        lambda e: "a" in e or "p" in e,
        start_count=0,
    )
    verb = "\\verb|{}|".format
    keeps_a(Probe(reduce_abc, []))
    result = reduce(keeps_a, reduce_abc)
    print(f"&& {verb(result.input)} & true \\\\")