diff --git a/irlc/gridworld/demo_agents/hidden_agents.py b/irlc/gridworld/demo_agents/hidden_agents.py index d831b118fac4c28ab77e3f44457c63ca727fd671..641466289fdbfe076261ac72ed47cc7898dc911f 100644 --- a/irlc/gridworld/demo_agents/hidden_agents.py +++ b/irlc/gridworld/demo_agents/hidden_agents.py @@ -63,6 +63,10 @@ class PolicyEvaluationAgent2(TabularAgent): self.policy[s][a] = 1/len(mdp.A(s)) super().__init__(env, gamma) + def reset(self): + self.v = defaultdict(lambda: 0) + + def pi(self, s,k, info=None): # TODO: 1 lines missing. @@ -162,6 +166,11 @@ class ValueIterationAgent3(TabularAgent): self.policy[s][a] = 1/len(mdp.A(s)) super().__init__(env, gamma, epsilon=epsilon) + def reset(self): + self.v = defaultdict(lambda: 0) + self.Q.q_.clear() + + def pi(self, s,k, info=None): from irlc import Agent