From 2520bd75b52df6e06e2c5f3132ecdb0e1b9eaf34 Mon Sep 17 00:00:00 2001 From: Lionel Miller Date: Sun, 2 Feb 2020 15:24:49 +0300 Subject: [PATCH] [week6] Fix accidental TimeLimit in CartPole --- week06_policy_based/reinforce_tensorflow.ipynb | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/week06_policy_based/reinforce_tensorflow.ipynb b/week06_policy_based/reinforce_tensorflow.ipynb index c46aaaacf..bf36035f3 100644 --- a/week06_policy_based/reinforce_tensorflow.ipynb +++ b/week06_policy_based/reinforce_tensorflow.ipynb @@ -42,14 +42,13 @@ "source": [ "import gym\n", "import numpy as np\n", - "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "env = gym.make(\"CartPole-v0\")\n", "\n", "# gym compatibility: unwrap TimeLimit\n", - "if hasattr(env, 'env'):\n", + "if hasattr(env, '_max_episode_steps'):\n", " env = env.env\n", "\n", "env.reset()\n", @@ -260,7 +259,7 @@ "metadata": {}, "outputs": [], "source": [ - "def generate_session(t_max=1000):\n", + "def generate_session(env, t_max=1000):\n", " \"\"\"play env with REINFORCE agent and train at the session end\"\"\"\n", "\n", " # arrays to record session\n", @@ -302,7 +301,7 @@ "\n", "for i in range(100):\n", "\n", - " rewards = [generate_session() for _ in range(100)] # generate new sessions\n", + " rewards = [generate_session(env) for _ in range(100)] # generate new sessions\n", "\n", " print(\"mean reward:%.3f\" % (np.mean(rewards)))\n", "\n", @@ -326,10 +325,9 @@ "source": [ "# record sessions\n", "import gym.wrappers\n", - "env = gym.wrappers.Monitor(gym.make(\"CartPole-v0\"),\n", - " directory=\"videos\", force=True)\n", - "sessions = [generate_session() for _ in range(100)]\n", - "env.close()" + "monitor_env = gym.wrappers.Monitor(gym.make(\"CartPole-v0\"), directory=\"videos\", force=True)\n", + "sessions = [generate_session(monitor_env) for _ in range(100)]\n", + "monitor_env.close()" ] }, {