from dataclasses import dataclass
from functools import partial

import rtree


@dataclass
class Let:
    """Local binding node: ``var := assignee; body``."""

    var: str  # name bound by this node
    assignee: "Expr"  # value bound to ``var``; evaluated outside the binding
    body: "Expr"  # expression in which ``var`` is in scope

    def replace(self, reps):
        """Return a copy with variables substituted according to *reps*.

        *reps* maps variable names to replacement expressions. The assignee
        is outside this binding's scope, so it is substituted with the full
        mapping. ``self.var`` is rebound by this node, so it is removed from
        the mapping before descending into the body: occurrences of it there
        refer to this binding, not to any outer one.

        NOTE(review): no alpha-renaming is performed, so a replacement whose
        own free variables include ``self.var`` would be captured by this
        binding.
        """
        # Shadowing: the inner binding hides any outer substitution for var.
        body_reps = {name: e for name, e in reps.items() if name != self.var}
        return Let(
            self.var,
            self.assignee.replace(reps),
            self.body.replace(body_reps),
        )

    def eval(self, vals=None):
        """Evaluate the assignee under *vals*, then the body with ``var`` bound.

        The caller's mapping is copied before the binding is added, so the
        binding does not leak out of this node.
        """
        vals = vals or dict()
        r = self.assignee.eval(vals)
        vals = vals.copy()
        vals[self.var] = r
        return self.body.eval(vals)

    def subterms(self):
        """Yield this node, then all subterms of the assignee, then of the body."""
        yield self
        yield from self.assignee.subterms()
        yield from self.body.subterms()

    def pretty(self, prec=0):
        """Render as ``var := assignee; body``.

        The assignee is rendered at precedence 2 and the body at precedence 1;
        the whole binding is parenthesized when *prec* >= 1.
        """
        p = f"{self.var} := {self.assignee.pretty(2)}; {self.body.pretty(1)}"
        if prec >= 1:
            p = f"({p})"
        return p


@dataclass
class Add:
    """Binary addition node: ``lhs + rhs``."""

    lhs: "Expr"  # left operand
    rhs: "Expr"  # right operand

    def replace(self, reps):
        """Return a new ``Add`` whose operands have been substituted via *reps*."""
        new_lhs = self.lhs.replace(reps)
        new_rhs = self.rhs.replace(reps)
        return Add(new_lhs, new_rhs)

    def subterms(self):
        """Yield this node, then every subterm of lhs, then every subterm of rhs."""
        yield self
        for side in (self.lhs, self.rhs):
            yield from side.subterms()

    def pretty(self, n=0):
        """Render as ``lhs + rhs``, parenthesized when *n* >= 3.

        The left operand is rendered at precedence 4 and the right at 3.
        """
        text = " + ".join((self.lhs.pretty(4), self.rhs.pretty(3)))
        return f"({text})" if n >= 3 else text

    def eval(self, vals=None):
        """Evaluate both operands (left first) under *vals* and return their sum."""
        env = vals if vals else {}
        return self.lhs.eval(env) + self.rhs.eval(env)


@dataclass
class Const:
    """Literal constant node."""

    n: int  # the literal value

    def replace(self, _):
        """Constants contain no variables; substitution returns this node itself."""
        return self

    def subterms(self):
        """Yield only this node, since a constant has no children."""
        yield from (self,)

    def pretty(self, _=0):
        """Render the literal; the precedence argument is irrelevant for atoms."""
        rendered = str(self.n)
        return rendered

    def eval(self, _vals=None):
        """Return the literal value, ignoring any environment."""
        return self.n


@dataclass
class Var:
    """Variable reference node."""

    name: str  # identifier looked up in the evaluation environment

    def replace(self, reps):
        """Return the replacement for this variable from *reps*, or this node if none."""
        if self.name in reps:
            return reps[self.name]
        return self

    def subterms(self):
        """Yield only this node, since a variable has no children."""
        yield from (self,)

    def pretty(self, _=0):
        """Render the variable's name; precedence is irrelevant for atoms."""
        return self.name

    def eval(self, vals=None):
        """Look the variable up in *vals*.

        Unbound variables evaluate to -1 instead of raising.
        """
        return (vals if vals else {}).get(self.name, -1)


# Union of every expression node type; referenced by the "Expr" field
# annotations above and by reduce_expr's signature. Using `|` between classes
# at runtime requires Python 3.10+.
Expr = Let | Add | Const | Var


def reduce_expr(expr: Expr, check) -> Expr:
    if isinstance(expr, Var):
        return expr
    elif isinstance(expr, Const):
        if not expr.n == 0 and check("1) make zero"):
            return Const(0)
        return expr
    elif isinstance(expr, Add):
        lhs_ = reduce_expr(expr.lhs, check)
        rhs_ = reduce_expr(expr.rhs, check)
        if check("2) reduce to lhs"):
            return lhs_
        if check("3) reduce to rhs"):
            return rhs_
        return Add(lhs_, rhs_)
    elif isinstance(expr, Let):
        assignee_ = reduce_expr(expr.assignee, check)
        if check("4) reduce to assingee"):
            return assignee_
        if check(f"5) inline {expr.var!r}"):
            return reduce_expr(expr.body.replace({expr.var: assignee_}), check)
        body_ = reduce_expr(expr.body, check)
        return Let(expr.var, assignee_, body_)


if __name__ == "__main__":
    # Example program: x := 2; 1 + (y := x; 3 + y)
    expr = Let(
        "x", Const(2), Add(Const(1), Let("y", Var("x"), Add(Const(3), Var("y"))))
    )

    def input_format(a):
        """Render expression *a* as inline LaTeX wrapped in the syn macro."""
        return f"$\\syn{{{a.pretty()}}}$"

    def run(predicate):
        """Build an rtree.latex predicate, probe the root, and reduce ``expr``."""
        tracker = rtree.latex(
            predicate,
            query_format=str,
            input_format=input_format,
            start_count=0,
        )
        tree = partial(reduce_expr, expr)
        # NOTE(review): the probe's return value is discarded; presumably it is
        # called for its output side effect -- confirm against rtree.
        tracker(rtree.Probe(tree, []))
        return rtree.reduce(tracker, tree)

    # Predicate 1: true when the expression still contains an Add node.
    best = run(lambda e: any(isinstance(a, Add) for a in e.subterms()))
    print(f"& \\verb|{rtree.pretty(best)}| & {input_format(best.input)} & true \\\\")

    print()

    # Predicate 2: true when the expression evaluates to the original's value.
    best = run(lambda e: e.eval() == expr.eval())
    print(f"& {rtree.pretty(best)} & {input_format(best.input)} & true \\\\")