Skip to content
Snippets Groups Projects
viz_inventory_environment.py 8.09 KiB
Newer Older
  • Learn to ignore specific revisions
  • tuhe's avatar
    tuhe committed
    # This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
    import numpy as np
    import pygame
    from irlc.ex01.inventory_environment import InventoryEnvironment
    from irlc.utils.graphics_util_pygame import formatColor
    
    class VizInventoryEnvironment(InventoryEnvironment):
        """Inventory environment with a pygame visualization and keyboard bindings.

        Most of the additional code calls the visualization and sets up keybindings.
        The one behavioral change concerns termination: Gymnasium-style loops typically
        reset as soon as an episode terminates, so the terminal state would normally
        never be rendered. To make it visible, the step that reaches the last stage
        reports ``terminated=False`` and the *following* call to :meth:`step` reports
        ``terminated=True`` (with zero reward) without changing the state.
        """

        metadata = {'render_modes': ['human', 'rgb_array'],
                    'render_fps': 30
                    }

        def __init__(self, N=2, render_mode='human'):
            """Create the environment.

            Args:
                N: Passed unchanged to ``InventoryEnvironment.__init__``; ``step`` compares
                    the stage counter ``self.k`` against ``self.N - 1``.
                render_mode: Forwarded to the viewer's ``blit`` on every :meth:`render`.
            """
            super().__init__(N)
            self.render_mode = render_mode
            self.viewer = None  # Created lazily on the first call to render().
            # True between the last real transition and the extra terminating step.
            self.in_term_state = False

        def get_keys_to_action(self):
            """Return a mapping ``{(pygame_key,): action}`` for the keys 0-3.

            Only as many keys as there are actions (``self.action_space.n``) are mapped.
            """
            digit_keys = (pygame.K_0, pygame.K_1, pygame.K_2, pygame.K_3)
            return {(key,): action for action, key in enumerate(digit_keys[:self.action_space.n])}

        def reset(self):
            """Reset the environment, clear the per-step display data and render.

            Returns:
                The ``(state, info)`` pair returned by the parent's ``reset``.
            """
            s, info = super().reset()
            self.s = s
            # Clearing the flag matters when reset() is called after the last real
            # transition but before the extra terminating step: otherwise the first
            # step of the new episode would immediately report termination.
            self.in_term_state = False
            # action=None doubles as the "restart" signal for the viewer (see render).
            self.action = None
            self.w = None
            self.reward = None
            self.render()
            return s, info

        def step(self, a):
            """Apply action ``a`` and return ``(state, reward, terminated, truncated, info)``.

            The transition that reaches the last stage returns ``terminated=False`` and
            arms ``in_term_state``; the next call returns the unchanged state with reward
            0 and ``terminated=True`` so that the final state gets a chance to be drawn.
            """
            self.action = a
            print(f"Step using {a=}")
            if self.in_term_state:
                # Extra, state-preserving step whose only purpose is to signal termination.
                self.reward = 0
                self.k += 1
                self.in_term_state = False
                return self.s, 0, True, False, {}

            # The transition is computed here rather than via super().step(a);
            # presumably so the disturbance w can be stored for rendering.
            w = np.random.choice(3, p=(.1, .7, .2))  # Generate random disturbance
            self.w = w
            s_next = max(0, min(2, self.s - w + a))
            reward = -(a + (self.s + a - w) ** 2)  # reward = -cost      = -g_k(x_k, u_k, w_k)
            terminated = self.k == self.N - 1  # Have we terminated? (i.e. is k==N-1)
            self.s = s_next  # update environment state
            self.k += 1
            if terminated:
                self.in_term_state = True
            self.reward = reward
            # Termination is deliberately deferred to the next call (see docstring).
            return s_next, reward, False, False, {}

        def render(self, mode='human', agent=None, prev_action=None, reward=None):
            """Draw the current state and return the result of the viewer's ``blit``.

            ``mode``, ``prev_action`` and ``reward`` are accepted but not used; the render
            mode given to the constructor and the values stored by :meth:`step` /
            :meth:`reset` are used instead.

            Args:
                agent: Only used when no ``agent`` attribute has been attached to the
                    environment (this class itself never sets one).
            """
            if self.viewer is None:
                self.viewer = InventoryViewer(self, frames_per_second=self.metadata['render_fps'])
            print(f"render: {self.action=}")
            # reset() calls render() before anything had a chance to attach an agent, so
            # do not assume the attribute exists.
            current_agent = getattr(self, 'agent', agent)
            self.viewer.update(current_agent, state=self.s, k=self.k, action=self.action,
                               reward=self.reward, w=self.w, restart=self.action is None)
            return self.viewer.blit(render_mode=self.render_mode)

        def close(self):
            """Close the viewer if one was created; safe to call more than once."""
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
    
    
    from irlc.pacman.pacman_resources import WHITE, BLACK, Ghost
    from irlc.utils.graphics_util_pygame import GraphicsUtilGym
    
    class InventoryViewer:
        """Draws the inventory problem as a row of ``Factory`` boxes, one per stage.

        The viewer owns a ``GraphicsUtilGym`` drawing surface and a list of ``Factory``
        objects; ``update`` extends that list as the stage counter grows and then
        redraws everything.
        """

        scale = 400  # Scale of a single bar (shadowed per instance in __init__).
        width = 0.4 * scale  # Width of a bar.

        def __init__(self, inventory : InventoryEnvironment, frames_per_second=None):
            """Open the drawing surface.

            Args:
                inventory: Environment being visualized; its ``N`` determines the
                    horizontal extent of the local coordinate system and the title.
                frames_per_second: Forwarded to ``GraphicsUtilGym.begin_graphics``.
            """
            # Bookkeeping attributes.
            self.k = 0
            self.states = []
            self.actions = []
            self.factories = []
            self.inventory = inventory

            # Extent of the local coordinate system.
            x_lo, x_hi = -0.2, inventory.N*2 + 1.4
            y_lo, y_hi = -0.4, 1.4
            span_x = x_hi - x_lo
            span_y = y_hi - y_lo

            self.ga = GraphicsUtilGym()
            screen_width = 1300
            pixels_per_unit = screen_width / span_x
            # Note the bounds are passed as (xmin, xmax, ymax, ymin), i.e. y is given
            # top-first.
            self.ga.begin_graphics(
                screen_width,
                span_y * pixels_per_unit,
                local_xmin_xmax_ymin_ymax=(x_lo, x_hi, y_hi, y_lo),
                frames_per_second=frames_per_second,
                color=formatColor(0, 0, 0),
                title=f"Inventory environment using N={inventory.N}",
            )
            self.last_action = None
            self.agent = None
            self.last_reward = None
            self.scale = pixels_per_unit

        def close(self):
            """Close the underlying drawing surface."""
            self.ga.close()

        def blit(self, render_mode='human'):
            """Delegate to ``GraphicsUtilGym.blit`` and return whatever it returns."""
            return self.ga.blit(render_mode=render_mode)

        def master_render(self):
            """Redraw the background, every factory and (if available) episode statistics."""
            self.ga.draw_background()
            for box in self.factories:
                box.render()

            # Statistics are only shown when the environment carries an
            # ``_interactive_data`` mapping containing an average reward.
            if not hasattr(self.inventory, '_interactive_data'):
                return
            if 'avg_reward_per_episode' not in self.inventory._interactive_data:
                return

            avg_reward = self.inventory._interactive_data['avg_reward_per_episode']
            episodes = self.inventory._interactive_data['completed_episodes']
            self.ga.text("sadf", (0.1, -0.1), WHITE, contents=f"Completed episodes = {episodes}",
                         size=12, style='bold', anchor='w')
            self.ga.text("sadf", (0.1, -0.2), WHITE, contents=f"Average reward per episode = {avg_reward:.2f}",
                         size=12, style='bold', anchor='w')

        def update(self, agent, k, state, action, reward, w, restart=False):
            """Record the latest transition and redraw.

            Args:
                agent: Stored on the viewer as ``self.agent``.
                k: Current stage index; a new factory is appended when no factory exists
                    for it yet.
                state: State shown in a newly created factory.
                action, reward, w: Attached to the factory at index ``k-1`` when a new
                    factory is appended and the list holds at most ``N+1`` factories.
                restart: If true, the list is first replaced by a single stage-0 factory.
            """
            self.agent = agent
            if restart:
                first_box = Factory(graphics_adaptor=self.ga, x=0, y=0, k=0, state=state)
                self.factories = [first_box]

            stage_is_new = len(self.factories) <= k
            if stage_is_new:
                new_box = Factory(graphics_adaptor=self.ga, x=k*2, y=0, k=k, state=state)
                self.factories.append(new_box)

                within_horizon = len(self.factories) <= self.inventory.N+1
                if within_horizon:
                    # Annotate the factory at index k-1 with what happened during it.
                    previous_box = self.factories[k-1]
                    previous_box.action = action
                    previous_box.w = w
                    previous_box.reward = reward

            self.master_render()
    
    
    class Factory:
        """One stage ``k`` of the inventory problem, drawn as a unit box plus annotations.

        The box shows the stage index, the state (also as a row of circles) and, once
        they have been assigned, the action, the disturbance ``w`` and the reward.
        """

        def __init__(self, graphics_adaptor, x, y, order=1, scale=10., k=1, state=2):
            """Store the drawing surface and the data to display.

            Args:
                graphics_adaptor: Drawing surface; ``render`` calls its ``rectangle``,
                    ``text``, ``circle``, ``line`` and ``scale`` methods and reads ``surf``.
                x: Horizontal position of the left edge of the box (local coordinates).
                y: Stored as ``self.y``; ``render`` always draws the box at ``y=0``.
                order: Accepted for compatibility; not used.
                scale: Stored as ``self.scale``; ``render`` uses the adaptor's scale instead.
                k: Stage index shown in the labels.
                state: State shown in the label and as that many circles.
            """
            self.ga = graphics_adaptor
            self.x = x
            self.y = y

            self.scale = scale
            self.s = state
            # Filled in later (by the viewer) once the transition out of this stage is known.
            self.action = None
            self.reward = None
            self.w = None
            self.k = k

        def render(self):
            """Draw the box, its labels and any action / disturbance / reward annotations."""
            # Imported at the top of the method: the sprite class is needed by both the
            # action branch and the disturbance branch. Importing it inside the action
            # branch only would leave the name unbound when w is set but action is not.
            from irlc.utils.graphics_util_pygame import Object

            ga = self.ga
            ga.rectangle(color=WHITE, x=self.x, y=0, width=1, height=1, border=1)
            ga.text("sadf", (self.x + 0.5, 1.1), WHITE, contents=f"day k = {self.k}",
                    size=12,
                    style='bold', anchor='c')

            ga.text("sadf", (self.x + 0.5, 0.8), WHITE, contents=f"state s_{self.k} = {self.s}",
                    size=12,
                    style='bold', anchor='c')

            mw = 1  # Width of the box in local coordinates.
            dh = 0.1  # Vertical offset applied to the annotations right of the box.
            scale = ga.scale()  # Loop-invariant, so queried once.

            # One circle per unit of state; the offset term centers the row within
            # three circle slots.
            rad = mw / 3 / 2
            for h in range(self.s):
                loc = self.x + rad * 2 * h + rad + mw * (3 - self.s) / 3 / 2
                ga.circle("sadf", (loc, rad), 0.8 * rad * scale, fillColor=WHITE)

            if self.action is not None:
                ga.text("sdaf", (self.x + 1.5, 0.8 + dh), WHITE, contents=f"action = {self.action}", size=12, style="bold", anchor="c")
                # Arrow from this box towards the next one: a shaft plus two head strokes.
                ex = self.x + 1.9
                ey = 0.5 + dh
                ga.line("sadf", (self.x + 1.1, 0.5 + dh), (ex, ey), color=WHITE, width=2)
                ga.line("sadf", (ex, ey), (ex-0.05, ey-0.05), color=WHITE, width=2)
                ga.line("sadf", (ex, ey), (ex-0.05, ey+0.05), color=WHITE, width=2)

                # One truck image per unit of the action.
                for a in range(self.action):
                    self.truck = Object(file="truck.jpg", graphics=ga, image_width=0.25 * scale)
                    self.truck.move_center_to_xy(self.x + 1 + 0.2 + a * 0.2, 0.6 + dh)
                    self.truck.flip()
                    self.truck.blit(ga.surf)

            if self.w is not None:
                ga.text("asdf", (self.x + 1.5, dh+0.05), WHITE, contents=f"w_{self.k} = {self.w}", size=12, style="bold",
                        anchor="c")
                # One customer image per unit of the disturbance.
                for w in range(self.w):
                    self.customer = Object(file="customer.jpg", graphics=ga, image_width=0.25*scale)
                    self.customer.move_center_to_xy(self.x + 1 + 0.2 + w * 0.2, 0.3 + dh)
                    self.customer.blit(ga.surf)

            if self.reward is not None:
                ga.text("asdf", (self.x + 1.5, dh-0.02), WHITE, contents=f"reward = {self.reward}", size=12, style="bold",
                        anchor="c")