From 36ccf2376d55ffa7cc676f9bfb39ac94bf658600 Mon Sep 17 00:00:00 2001
From: Christian Gram Kalhauge <chrg@dtu.dk>
Date: Tue, 5 Mar 2024 11:22:25 +0100
Subject: [PATCH] Fix small problems

---
 pyrtree/rtree.py                              | 24 +++++++++----------
 rtree/src/Control/Monad/IRTree.hs             | 15 ++++++------
 rtree/test/expected/double-let-expr-ired      |  1 -
 .../expected/double-overloading-let-expr-ired |  1 -
 rtree/test/expected/small-let-expr-ired       |  1 -
 rtree/test/expected/small-opr-expr-ired       |  1 -
 6 files changed, 19 insertions(+), 24 deletions(-)

diff --git a/pyrtree/rtree.py b/pyrtree/rtree.py
index bec022b..d484af3 100644
--- a/pyrtree/rtree.py
+++ b/pyrtree/rtree.py
@@ -11,8 +11,8 @@ class ReducePath:
         self.path.append(choice)
         return self
 
-    def didGuess(self):
-        return self.index > len(self.path)
+    def left(self):
+        return self.index - len(self.path)
 
     def dispensable(self):
         self.index += 1
@@ -25,16 +25,16 @@ class ReducePath:
 def reduce(predicate, rtree):
     r = ReducePath([])
     i = rtree(r)
-
-    if not predicate(i):
-        return None
-
-    while r.didGuess():
-        # Explore the left tree
-        i = rtree(r.explore(True))
-        # If the predcate fails, move right
-        r.path[-1] = predicate(i)
-
+    it = rtree(r.explore(True))
+    # While we don't consume all choices going down the true branch
+    while r.left() >= 0:
+        if predicate(it):
+            # If true update the valid input
+            i = it
+        else:
+            # If false we have to go down the left branch.
+            r.path[-1] = False
+        it = rtree(r.explore(True))
     return i
 
 
diff --git a/rtree/src/Control/Monad/IRTree.hs b/rtree/src/Control/Monad/IRTree.hs
index 706c9db..b8c3a22 100644
--- a/rtree/src/Control/Monad/IRTree.hs
+++ b/rtree/src/Control/Monad/IRTree.hs
@@ -92,15 +92,14 @@ reduceT
   -> IRTreeT l t i
   -> m i
 reduceT lift_ p rt = do
-  Seq.empty & fix \rec sq -> do
-    -- Try to run the true branch.
+  (k', _, _) <- _probe Seq.empty
+  (\f -> f Seq.empty k') $ fix \rec sq k -> do
     (i, l, left) <- _probe (sq Seq.|> True)
-    p l i >>= \case
-      -- If predicate is true, and there is choices left
-      True | left > 0 -> rec (sq Seq.|> True)
-      -- If predicate is false (and stable)
-      False | left >= 0 -> rec (sq Seq.|> False)
-      _ow -> pure i
+    if left < 0
+      then pure k
+      else do
+        t <- p l i
+        rec (sq Seq.|> t) (if t then i else k)
  where
   _probe sq = lift_ . probeT rt . fromChoiceList $ toList sq
 {-# INLINE reduceT #-}
diff --git a/rtree/test/expected/double-let-expr-ired b/rtree/test/expected/double-let-expr-ired
index 3dce08b..35bdda2 100644
--- a/rtree/test/expected/double-let-expr-ired
+++ b/rtree/test/expected/double-let-expr-ired
@@ -3,4 +3,3 @@
 111: 1 False
 1101: 2 False
 11001: 3 False
-11000: 1 + 2 True
diff --git a/rtree/test/expected/double-overloading-let-expr-ired b/rtree/test/expected/double-overloading-let-expr-ired
index 5d756a4..4f29850 100644
--- a/rtree/test/expected/double-overloading-let-expr-ired
+++ b/rtree/test/expected/double-overloading-let-expr-ired
@@ -3,4 +3,3 @@
 111: 2 False
 1101: 2 False
 11001: 4 False
-11000: 2 + 2 True
diff --git a/rtree/test/expected/small-let-expr-ired b/rtree/test/expected/small-let-expr-ired
index 8ec2335..ab4bde1 100644
--- a/rtree/test/expected/small-let-expr-ired
+++ b/rtree/test/expected/small-let-expr-ired
@@ -2,4 +2,3 @@
 11: 2 False
 101: 1 False
 1001: 3 False
-1000: 2 + 1 True
diff --git a/rtree/test/expected/small-opr-expr-ired b/rtree/test/expected/small-opr-expr-ired
index e8e5ee3..63b780a 100644
--- a/rtree/test/expected/small-opr-expr-ired
+++ b/rtree/test/expected/small-opr-expr-ired
@@ -1,4 +1,3 @@
 1: 1 False
 01: 2 False
 001: 3 False
-000: 1 + 2 True
-- 
GitLab