Skip to content
Snippets Groups Projects
tests_week01.py 5.7 KiB
Newer Older
  • Learn to ignore specific revisions
  • tuhe's avatar
    tuhe committed
    # This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
    from unitgrade import Report
    import irlc
    # from irlc.ex01.frozen_lake import FrozenAgentDownRight
    import gymnasium as gym
    from unitgrade import UTestCase
    from irlc.ex01.inventory_environment import InventoryEnvironment, simplified_train, RandomAgent
    from unitgrade import Capturing2
    import numpy as np
    from gymnasium.envs.toy_text.frozen_lake import RIGHT, DOWN  # The down and right-actions; may be relevant.
    from irlc.ex01.pacman_hardcoded import GoAroundAgent, layout
    from irlc.pacman.pacman_environment import PacmanEnvironment
    from irlc import Agent, train
    from irlc.ex01.bobs_friend import BobFriendEnvironment, AlwaysAction_u1, AlwaysAction_u0
    
    
    # Problem 1: checks reset() and a single step with action 0 in BobFriendEnvironment.
    class Problem1BobsFriend(UTestCase):
        def test_a_env_basic(self):
            # reset() returns a (state, info) pair; only the state is inspected here.
            initial_state, _info = BobFriendEnvironment().reset()
            self.assertEqual(initial_state, 20, msg="Reset must return the initial state, i.e. the amount of money we start out with")

        def test_a_env_u0(self):
            # Take action 0 once from the reset state and inspect the resulting transition.
            bank = BobFriendEnvironment()
            bank.reset()
            next_state, reward, terminated, _truncated, _info = bank.step(0)
            self.assertEqual(reward, 2, msg="When taking action u0, we must get a reward of 2.")
            self.assertEqual(next_state, 22, msg="When taking action u0, we must end in state x1=22")
            self.assertEqual(terminated, True, msg="After taking an action, the environment must terminate")
    
    # Problem 2: checks the action-1 transition of BobFriendEnvironment and the average
    # reward obtained by the two constant policies (AlwaysAction_u0 / AlwaysAction_u1).
    class Problem2BobsPolicy(UTestCase):
        @staticmethod
        def _mean_accumulated_reward(env, agent, num_episodes):
            # Run `num_episodes` episodes through train() and average the
            # per-episode 'Accumulated Reward' entries of the returned statistics.
            stats, _ = train(env, agent, num_episodes=num_episodes)
            return np.mean([stat['Accumulated Reward'] for stat in stats])

        def test_a_env_u1(self):
            env = BobFriendEnvironment()
            env.reset()
            s1, r, done, _, _ = env.step(1)
            # Echo the reward that was actually observed; the test accepts two possible outcomes.
            print(r)
            # The failure messages below state the same values the assertions check (12 / -20 and 0 / 32).
            self.assertTrue(r == 12 or r == -20, msg="When taking action u1, we must get a reward of -20 or 12.")
            self.assertTrue(s1 == 0 or s1 == 32, msg="When taking action u1, we must end in state x1=0 or x1=32")
            self.assertEqual(done, True, msg="After taking an action, the environment must terminate")

        def test_b_always_action_u0(self):
            env = BobFriendEnvironment()
            avg = self._mean_accumulated_reward(env, AlwaysAction_u0(env), num_episodes=1000)
            self.assertL2(avg, 2, msg="Average reward when we always take action u=0 must be 2.")

        def test_b_always_action_u1(self):
            env = BobFriendEnvironment()
            avg = self._mean_accumulated_reward(env, AlwaysAction_u1(env), num_episodes=10000)
            # The average is an estimate over many episodes, hence the explicit tolerance.
            self.assertL2(avg, 4, tol=0.5, msg="Average reward when we always take action u=1 must be about 4.")

        def test_b_always_action_u1_starting_200(self):
            env = BobFriendEnvironment(x0=200)
            avg = self._mean_accumulated_reward(env, AlwaysAction_u1(env), num_episodes=10000)
            self.assertL2(avg, -42, tol=4, msg="Average reward when we always take action u=1 (starting with x0=200) must be about -42.")

        def test_b_always_action_u0_starting_200(self):
            env = BobFriendEnvironment(x0=200)
            avg = self._mean_accumulated_reward(env, AlwaysAction_u0(env), num_episodes=10000)
            self.assertL2(avg, 20, msg="Average reward when we always take action u=0 (starting with x0=200) must be 20.")
    
    
    
    class Problem5PacmanHardcoded(UTestCase):
        """ Test the hardcoded pacman agent """
        def test_pacman(self):
            # Run GoAroundAgent for a single episode on the layout imported from pacman_hardcoded.
            env = PacmanEnvironment(layout_str=layout)
            agent = GoAroundAgent(env)
            stats, _ = train(env, agent, num_episodes=1)
            # assertLess reports the actual length on failure, unlike comparing a boolean to True.
            self.assertLess(stats[0]['Length'], 100, msg="The 'Length' of the episode must be less than 100.")
    
    
    # Problem 6: runs the chess script's main() and checks a number found in its console output.
    class Problem6ChessTournament(UTestCase):
        def test_chess(self):
            """ Test the correct result in the little chess-tournament """
            from irlc.ex01.chess import main as run_chess_main
            with Capturing2() as captured:
                run_chess_main()
            # The numbers extracted from the captured console output are echoed for inspection.
            print("Numbers extracted from console output was")
            extracted = captured.numbers
            print(extracted)
            # The second-to-last extracted number is compared against 26/33.
            self.assertLinf(extracted[-2], 26/33, tol=0.05)
    
    # Problem 3: average accumulated reward in InventoryEnvironment for the base Agent and RandomAgent.
    # No explicit expected value is passed to assertLinf; presumably unitgrade compares against a
    # stored reference value -- confirm against the unitgrade documentation.
    class Problem3InventoryInventoryEnvironment(UTestCase):
        def test_environment(self):
            inventory = InventoryEnvironment()
            base_agent = Agent(inventory)
            episodes, _ = train(inventory, base_agent, num_episodes=2000, verbose=False)
            rewards = [episode['Accumulated Reward'] for episode in episodes]
            self.assertLinf(np.mean(rewards), tol=0.6)

        def test_random_agent(self):
            inventory = InventoryEnvironment()
            random_agent = RandomAgent(inventory)
            episodes, _ = train(inventory, random_agent, num_episodes=2000, verbose=False)
            rewards = [episode['Accumulated Reward'] for episode in episodes]
            self.assertLinf(np.mean(rewards), tol=0.6)
    
    # Problem 4: averages the value returned by simplified_train over 1000 calls.
    class Problem4InventoryTrain(UTestCase):
        def test_simplified_train(self):
            inventory = InventoryEnvironment()
            base_agent = Agent(inventory)
            # The same environment and agent instances are reused for every call.
            results = [simplified_train(inventory, base_agent) for _ in range(1000)]
            self.assertLinf(np.mean(results), tol=0.5)
    
    # class FrozenLakeTest(UTestCase):
    #     def test_frozen_lake(self):
    #         env = gym.make("FrozenLake-v1")
    #         agent = FrozenAgentDownRight(env)
    #         s = env.reset()
    #         for k in range(10):
    #             self.assertEqual(agent.pi(s, k), DOWN if k % 2 == 0 else RIGHT)
    
    
    # Report definition for week 01: bundles the six Problem* test classes in this file.
    # NOTE(review): the trailing "#240 total." remark does not match the numbers listed in
    # `questions`, which sum to 60 -- confirm what 240 refers to or update the remark.
    class Week01Tests(Report): #240 total.
        title = "Tests for week 01"
        # Passes the irlc package to unitgrade; presumably the code that gets packed with the hand-in -- TODO confirm.
        pack_imports = [irlc]
        individual_imports = []
        # Each entry pairs a test class with a number; presumably the points/weight of that question -- TODO confirm.
        questions = [
            (Problem1BobsFriend, 10),
            (Problem2BobsPolicy, 10),
            (Problem3InventoryInventoryEnvironment, 10),
            (Problem4InventoryTrain, 10),
            (Problem5PacmanHardcoded, 10),
            (Problem6ChessTournament, 10),      # Week 1: Everything
                     ]
    
    # When executed as a script, evaluate the week-01 report with unitgrade's student entry point.
    if __name__ == "__main__":
        from unitgrade import evaluate_report_student
        week01_report = Week01Tests()
        evaluate_report_student(week01_report)