{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5bdbe199",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Cart Pole Reinforcement Learning\n",
    "\n",
    "This notebook demonstrates implementing a custom [Gymnasium](https://gymnasium.farama.org/index.html) environment using kdFlex. By implementing the simple [Cart Pole](https://gymnasium.farama.org/environments/classic_control/cart_pole/) scenario in a Gymnasium class, any machine learning library that is compatible with Gymnasium can be trained on the scenario. This tutorial consists of 2 parts: building the environment and training a model on the environment. \n",
    "\n",
    "Requirements:\n",
    "- [2-link Pendulum](../example_2_link_pendulum/notebook.ipynb)\n",
    "\n",
    "In this tutorial we will:\n",
    "- Implement [gymnasium.Env](https://gymnasium.farama.org/api/env/)\n",
    "  - Create the multibody\n",
    "  - Add visual geometries\n",
    "  - Setup the simulation\n",
    "  - Environment attributes\n",
    "  - Setup helpers\n",
    "  - Reset the environment\n",
    "  - Step the environment\n",
    "- Train the model\n",
    "  - Load the environment\n",
    "  - Train a DQN model\n",
    "  - Test model performance\n",
    "\n",
    "Scripts:\n",
    "- [env.py](./env.py)\n",
    "\n",
    "![](../resources/nb_images/cart_pole.png)\n",
    "\n",
    "For a more in-depth descriptions of **kdflex** concepts see [usage](usage_page)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7076028",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "First lets verify that the reinforcement learning packages are installed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0e5c5ee0",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    import stable_baselines3\n",
    "    import gymnasium\n",
    "except ModuleNotFoundError:\n",
    "    raise ModuleNotFoundError(\n",
    "        \"Please ensure that the following python packages are installed:\\n\"\n",
    "        \"    stable-baseline3 gymnasium\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f35ca1c9",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Implement gymnasium.Env\n",
    "Because classes can't be split across notebook cells, we will show snippets of code with explanations. The full class is found at [env.py](./env.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72a9b9fa",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from math import pi\n",
    "from typing import cast\n",
    "import gymnasium as gym\n",
    "from typing import Optional\n",
    "\n",
    "from Karana.Frame import FrameContainer\n",
    "from Karana.Core import discard, allReady\n",
    "from Karana.Dynamics import (\n",
    "    Multibody,\n",
    "    HingeType,\n",
    "    StatePropagator,\n",
    "    TimedEvent,\n",
    "    PhysicalBody,\n",
    "    PhysicalHinge,\n",
    "    LinearSubhinge,\n",
    "    PinSubhinge,\n",
    ")\n",
    "from Karana.Math import UnitQuaternion, HomTran, SpatialInertia\n",
    "from Karana.Scene import (\n",
    "    BoxGeometry,\n",
    "    CylinderGeometry,\n",
    "    Color,\n",
    "    PhysicalMaterialInfo,\n",
    "    PhysicalMaterial,\n",
    "    LAYER_GRAPHICS,\n",
    ")\n",
    "from Karana.Scene import ProxyScenePart, ProxyScene\n",
    "from Karana.Models import Gravity, UniformGravity, UpdateProxyScene, SyncRealTime, OutputUpdateType\n",
    "from Karana.Math import IntegratorType\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1db746e6",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Create the Multibody\n",
    "We create the cart by defining a physical cart and attaching it to the Multibody root using a slider hinge, allowing it to slide back and forth along 1 axis. \n",
    "\n",
    "The pole requires a specific inertia matrix, as the velocity which the cart is moved by a force depends on the angle of the pole. This can be calculated using a mass moment of inertia calculator. From there, the end of the pole is attached by a pin to swing like a pendulum, and transforms are defined to place where the body and joint are relative to each other.\n",
    "\n",
    "```python\n",
    "# create a cart and a pole multibody system\n",
    "def create_multibody(fc: FrameContainer) -> tuple[LinearSubhinge, PinSubhinge, Multibody]:\n",
    "    mb = Multibody(\"mb\", fc)\n",
    "\n",
    "    # add a cart of mass 1.0kg and 0.4 x 0.2 x 0.15 (L, W, H)\n",
    "    cart_height = 0.15\n",
    "    cart = PhysicalBody(\"cart\", mb)\n",
    "    cart.setSpatialInertia(SpatialInertia(1.0, np.zeros(3), np.eye(3)))\n",
    "    slider_hinge = PhysicalHinge(mb.virtualRoot(), cart, HingeType.SLIDER)\n",
    "    slider_subhinge = cast(LinearSubhinge, slider_hinge.subhinge(0))\n",
    "    slider_subhinge.setUnitAxis([1, 0, 0])\n",
    "    cart_to_slider = HomTran([0, 0, 0])\n",
    "    cart.setBodyToJointTransform(cart_to_slider)\n",
    "\n",
    "    # add a pole of mass 0.1 kg, and 0.02 x 0.5 (R, L)\n",
    "    pole_length = 0.5\n",
    "    pole = PhysicalBody(\"pole\", mb)\n",
    "    inertia_matrix = np.diag([0.0020933, 0.0020933, 0.00002])\n",
    "    pole.setSpatialInertia(SpatialInertia(0.1, [0, 0, 0], inertia_matrix))\n",
    "    pin_hinge = PhysicalHinge(cart, pole, HingeType.REVOLUTE)\n",
    "    pole_to_pin = HomTran([0.0, 0.0, -(pole_length / 2.0)])\n",
    "    pole.setBodyToJointTransform(pole_to_pin)\n",
    "    cart_to_pin = HomTran([0.0, 0.0, (cart_height / 2.0)])\n",
    "    pin_hinge.onode().setBodyToNodeTransform(cart_to_pin)\n",
    "    pin_subhinge = cast(PinSubhinge, pin_hinge.subhinge(0))\n",
    "    pin_subhinge.setUnitAxis([0, 1, 0])\n",
    "\n",
    "    # check\n",
    "    mb.ensureHealthy()\n",
    "    mb.resetData()\n",
    "    assert allReady()\n",
    "    return slider_subhinge, pin_subhinge, mb\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b00dd88b",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Add Visual Geometries\n",
    "This is purely for visualization.\n",
    "\n",
    "```python\n",
    "# add a visual rectangle and cylinder\n",
    "def add_geometries(mb: Multibody, proxy_scene: ProxyScene) -> None:\n",
    "    box_geom = BoxGeometry(0.4, 0.2, 0.15)\n",
    "    cylinder_geom = CylinderGeometry(0.02, 0.5)\n",
    "    mat_info = PhysicalMaterialInfo()\n",
    "    mat_info.color = Color.FIREBRICK\n",
    "    firebrick = PhysicalMaterial(mat_info)\n",
    "    mat_info.color = Color.GOLD\n",
    "    gold = PhysicalMaterial(mat_info)\n",
    "    mat_info.color = Color.WHITE\n",
    "    white = PhysicalMaterial(mat_info)\n",
    "\n",
    "    # add geometry to bodies\n",
    "    cart_body = ProxyScenePart(\n",
    "        \"cart_body\", scene=proxy_scene, geometry=box_geom, material=firebrick, layers=LAYER_GRAPHICS\n",
    "    )\n",
    "    cart_body.attachTo(mb.getBody(\"cart\"))\n",
    "    pole_body = ProxyScenePart(\n",
    "        \"pole_body\", scene=proxy_scene, geometry=cylinder_geom, material=gold, layers=LAYER_GRAPHICS\n",
    "    )\n",
    "    pole_body.setUnitQuaternion(UnitQuaternion(pi / 2, [1, 0, 0]))\n",
    "    pole_body.attachTo(mb.getBody(\"pole\"))\n",
    "\n",
    "    # add a floor\n",
    "    ground_geom = BoxGeometry(20, 20, 0.5)\n",
    "    ground = ProxyScenePart(\n",
    "        \"ground\", scene=proxy_scene, geometry=ground_geom, material=white, layers=LAYER_GRAPHICS\n",
    "    )\n",
    "    ground.setTranslation([0, 0, -0.325])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd8f9b98",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Initialize Attributes\n",
    "In gymnasium, Environments must implement a few key attributes for users and models to train on. For simplicity, we will only implement the most important ones of action_space, observation_space, and np_random.\n",
    "- **observation_space** defines the bounds of data that is input into the model. Here, our model sees a size 4 numpy array of the cart position, cart velocity, pin angle, and pin angular velocity.\n",
    "- **action_space** defines which actions an Agent may take. Here, the agent can choose either 0 or 1, which we convert into a negative or positive force.\n",
    "- **np_random** is a random number generator which can be seeded for reproducibility.\n",
    "- **render_mode** determines whether or not to show graphics visualizations. \"none\" is faster, but shows no graphics while human renders the webscene.\n",
    "\n",
    "```python\n",
    "class CartPoleEnv(gym.Env):\n",
    "    metadata = {\"render_modes\": [\"human\", \"none\"]}\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        render_mode: Optional[str] = None,\n",
    "        port: int = 0,\n",
    "        sync_real_time: bool = False,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        self.render_mode = \"none\"\n",
    "        if render_mode and render_mode in self.metadata[\"render_modes\"]:\n",
    "            self.render_mode = render_mode\n",
    "\n",
    "        # init simulation\n",
    "        self.init_sim(port=port, sync_real_time=sync_real_time)\n",
    "\n",
    "        # observation space of [cart pos, cart vel, pin pos, pin vel]\n",
    "        low = np.array([-1.5, -2, -pi / 3, -6.0], dtype=np.float32)\n",
    "        high = np.array([1.5, 2, pi / 3, 6.0], dtype=np.float32)\n",
    "        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32, shape=(4,))\n",
    "\n",
    "        # action space of negative or positive force on cart\n",
    "        self.action_space = gym.spaces.Discrete(2)\n",
    "        self._action_to_force = {0: -10.0, 1: 10.0}\n",
    "\n",
    "        # initialize values\n",
    "        self._cart_pos = None\n",
    "        self._cart_vel = None\n",
    "        self._pin_pos = None\n",
    "        self._pin_vel = None\n",
    "        self.np_random = None\n",
    "        self._force = None\n",
    "        self._terminate = None\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e00ba7a",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Setup the Simulation\n",
    "We now follow the essential simulation procedures that are explained in the 2-link pendulum example. This time however, we use attach two new callback functions to the StatePropagator.\n",
    "- {py:attr}`Karana.Dynamics.SpFunctions.pre_deriv_fns` is used to apply the force which the agent selects to the slider hinge. it must be called in the pre_derivative for the force to be properly evaluated.\n",
    "- {py:attr}`Karana.Dynamics.SpFunctions.terminate_advance_to_fns` will check for a terminating condition at the end of every hop. Here, if the cart or pole are too far deviated, the simulation terminates.\n",
    "\n",
    "```python\n",
    "    def init_sim(self, port: int = 0, sync_real_time: bool = False):\n",
    "        self._fc = FrameContainer(\"root\")\n",
    "        self._slider, self._pin, self._mb = create_multibody(self._fc)\n",
    "\n",
    "        # add visuals\n",
    "        if self.render_mode == \"human\":\n",
    "            self._cleanup_graphics, web_scene = self._mb.setupGraphics(port=port, axes=0)\n",
    "            self._proxy_scene = self._mb.getScene()\n",
    "            web_scene.defaultCamera().pointCameraAt([0, 3, 1], [0, 0, 0], [0, 0, 1])\n",
    "            add_geometries(self._mb, self._proxy_scene)\n",
    "        else:\n",
    "            self._proxy_scene = ProxyScene(f\"{self._mb.name()}_proxyscene\", self._mb.virtualRoot())\n",
    "\n",
    "        # add models\n",
    "        self._sp = StatePropagator(self._mb, integrator_type=IntegratorType.RK4)\n",
    "        # gravity model\n",
    "        # add a gravitational model to the state propagator\n",
    "        ug = Gravity(\"grav_model\", self._sp, UniformGravity(\"uniform_gravity\"), self._mb)\n",
    "        ug.getGravityInterface().setGravity(np.array([0, 0, -9.81]), 0.0, OutputUpdateType.PRE_HOP)\n",
    "        del ug\n",
    "\n",
    "        # for visualization\n",
    "        if sync_real_time:\n",
    "            SyncRealTime(\"sync_real_time\",self._sp, 1.0)\n",
    "\n",
    "        # update proxy scene model\n",
    "        UpdateProxyScene(\"update_proxy_scene\",self._sp, self._proxy_scene)\n",
    "\n",
    "        # apply slider force to move the cart\n",
    "        def apply_force(t, x):\n",
    "            self._slider.setT(self._force)\n",
    "\n",
    "        # check if pole or cart fall past a threshold\n",
    "        # max cart position is 1.5, max pole angle is ~12 (0.21)\n",
    "        def check_terminate(t, x):\n",
    "            cart_pos = self._slider.getQ()\n",
    "            pin_pos = self._pin.getQ()\n",
    "            if abs(cart_pos) > 1.5 or abs(pin_pos) > 0.21:\n",
    "                self._terminate = True\n",
    "                return True\n",
    "            return False\n",
    "\n",
    "        self._sp.fns.pre_deriv_fns[\"apply_force\"] = apply_force\n",
    "        self._sp.fns.terminate_advance_to_fns[\"check_terminate\"] = check_terminate\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75cc2967",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Setup Helpers\n",
    "These are a few helper functions for gymnasium. _get_obs composes the observation values into a numpy array for a model to train on, and _get_info is primarily for user debugging. \n",
    "\n",
    "```python\n",
    "    # what the model sees\n",
    "    def _get_obs(self):\n",
    "        return np.array(\n",
    "            [\n",
    "                float(self._cart_pos),\n",
    "                float(self._cart_vel),\n",
    "                float(self._pin_pos),\n",
    "                float(self._pin_vel),\n",
    "            ],\n",
    "            dtype=np.float32,\n",
    "        )\n",
    "\n",
    "    # for user debugging\n",
    "    def _get_info(self):\n",
    "        return {\n",
    "            \"cart_pos\": self._cart_pos,\n",
    "            \"cart_vel\": self._cart_vel,\n",
    "            \"pin_pos\": self._pin_pos,\n",
    "            \"pin_vel\": self._pin_vel,\n",
    "        }\n",
    "\n",
    "    def close(self):\n",
    "        del self._slider, self._pin, self._proxy_scene\n",
    "        discard(self._sp)\n",
    "        if self.render_mode == \"human\":\n",
    "            self._cleanup_graphics()\n",
    "        discard(self._mb)\n",
    "        discard(self._fc)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75f22fa6",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Reset the Environment\n",
    "Gymnasium composes scenarios into episodes and steps. An episode represents a learning scenario, and each step represents the choices that the model makes in each episode. For example, a cartpole with a random initial position and velocity is an episode. And the many forces the model applies at various timesteps are the steps.\n",
    "\n",
    "Reset is used to initialize a fresh episode. This involves clearing out the state of the previous episode and setting up new random initial values.\n",
    "\n",
    "```python\n",
    "    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):\n",
    "        super().reset(seed=seed)\n",
    "        if not self.np_random:\n",
    "            self.np_random, _ = gym.utils.seeding.np_random(seed)\n",
    "\n",
    "        # init values\n",
    "        self._cart_pos = self.np_random.uniform(-0.1, 0.1)\n",
    "        self._cart_vel = self.np_random.uniform(-0.1, 0.1)\n",
    "        self._pin_pos = self.np_random.uniform(-0.1, 0.1)\n",
    "        self._pin_vel = self.np_random.uniform(-0.1, 0.1)\n",
    "        self._force = 0.0\n",
    "        self._terminate = False\n",
    "\n",
    "        # set our state\n",
    "        self._mb.resetData()\n",
    "        self._slider.setQ(self._cart_pos)\n",
    "        self._slider.setU(self._cart_vel)\n",
    "        self._pin.setQ(self._pin_pos)\n",
    "        self._pin.setU(self._pin_vel)\n",
    "        t_init = np.timedelta64(0, \"ns\")\n",
    "        self._sp.setTime(t_init)\n",
    "        self._sp.setState(self._sp.assembleState())\n",
    "\n",
    "        # setup timed event\n",
    "        h = np.timedelta64(int(1e7), \"ns\")\n",
    "        t = TimedEvent(\"hop_size\", h, lambda _: None, False)\n",
    "        t.period = h\n",
    "        self._sp.registerTimedEvent(t)\n",
    "        del t\n",
    "        self._proxy_scene.update()\n",
    "        return self._get_obs(), self._get_info()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92f67444",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Step the Environment\n",
    "Every time the model makes a decision, it calls step with the action it chose. Therefore, we convert the action to a force, advance our simulator, and return the new observation values.\n",
    "\n",
    "By default, gymnasium expects observation, reward, terminated, truncated, info to be returned. As per the Cart Pole environment, the model is given a reward of 1.0 for each step it survives, and Truncated is always false. You can read about terminated vs truncated [here](https://farama.org/Gymnasium-Terminated-Truncated-Step-API).\n",
    "\n",
    "```python\n",
    "    def step(self, action):\n",
    "        self._force = self._action_to_force[action]\n",
    "        self._sp.advanceBy(0.02)\n",
    "        self._cart_pos = self._slider.getQ()\n",
    "        self._cart_vel = self._slider.getU()\n",
    "        self._pin_pos = self._pin.getQ()\n",
    "        self._pin_vel = self._pin.getU()\n",
    "        # observation, reward, terminated, truncated, info\n",
    "        return self._get_obs(), 1.0, self._terminate, False, self._get_info()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43988613",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Train the model\n",
    "Great! You are now able to train a model from any library compatible with Gymnasium on this environment. For this tutorial, we will use a model from [Stable Baselines3](https://stable-baselines3.readthedocs.io/en/master/) to balance the cartpole. Before we begin, make sure to add these dependencies by running this in your current python environment:\n",
    "```bash\n",
    "pip install stable-baselines3\n",
    "pip install gymnasium\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "708d4011",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from env import CartPoleEnv\n",
    "import gymnasium as gym\n",
    "from stable_baselines3 import DQN\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecfdd8d5",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Load the Environment\n",
    "\n",
    "Here, we define our training environment and our model. As long as you register an environment into gymnasium, you can use that same environment with gym.make. Since Stable Baselines3 expects environments to be vectorized, for our simple scenario we use a dummy vectorized environment wrapper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "78a43647",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[WebUI] Listening at http://newton:39199\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"300px\"\n",
       "            src=\"http://newton:39199\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7828d4329a00>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "gym.register(\n",
    "    id=\"kdFlex-CartPole\",\n",
    "    entry_point=CartPoleEnv,\n",
    "    max_episode_steps=500,\n",
    ")\n",
    "\n",
    "train_env = gym.make(\"kdFlex-CartPole\", render_mode=\"human\")\n",
    "train_env = DummyVecEnv([lambda: train_env])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58033980",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Train a DQN Model\n",
    "\n",
    "Because the action space is discrete, a Deep Q-Network (DQN) model is used with some basic hyperparameters."
   ]
  },
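  {
   "cell_type": "markdown",
   "id": "e3a9c41f",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "DQN explores with an epsilon-greedy policy whose random-action probability decays to exploration_final_eps. The sketch below is a plain-Python linear schedule, not the Stable Baselines3 implementation. The starting value of 1.0 and the decay over the first 10% of training are assumptions based on our understanding of the library defaults, so check the DQN documentation for your version. The name `linear_epsilon` is hypothetical.\n",
    "\n",
    "```python\n",
    "def linear_epsilon(step, total_steps, start=1.0, final=0.02, fraction=0.1):\n",
    "    # Assumed: epsilon decays linearly over the first `fraction` of training\n",
    "    decay_steps = fraction * total_steps\n",
    "    if step >= decay_steps:\n",
    "        return final\n",
    "    return start + (final - start) * step / decay_steps\n",
    "\n",
    "\n",
    "# Epsilon at a few points of a 100000-step run\n",
    "for step in (0, 5_000, 10_000, 100_000):\n",
    "    print(step, round(linear_epsilon(step, 100_000), 3))\n",
    "```"
   ]
  },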
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "919a51f9",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = DQN(\n",
    "    policy=\"MlpPolicy\",\n",
    "    env=train_env,\n",
    "    learning_rate=1e-3,\n",
    "    learning_starts=1000,\n",
    "    buffer_size=50000,\n",
    "    batch_size=64,\n",
    "    target_update_interval=250,\n",
    "    exploration_final_eps=0.02,\n",
    "    policy_kwargs=dict(net_arch=[64, 64]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c1180295",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# check if this is running in a test environment\n",
    "if os.getenv(\"DTEST_RUNNING\", True):\n",
    "    # in a testing environment, run for less\n",
    "    model.learn(total_timesteps=1000)\n",
    "else:\n",
    "    # this will take around 1-4 minutes\n",
    "    model.learn(total_timesteps=100000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d101f518",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Test the Model\n",
    "Our training environment was sped up, so to properly visualize the model we will create a separate test environment that is synced to real time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aca87959",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[WebUI] Listening at http://newton:33863\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"300px\"\n",
       "            src=\"http://newton:33863\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x782aa3ffeff0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_env.close()\n",
    "test_env = gym.make(\"kdFlex-CartPole\", render_mode=\"human\", sync_real_time=True)\n",
    "test_env = DummyVecEnv([lambda: test_env])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25793a85",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "You can now run your model on a test environment episode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fba09fbc",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "terminated = False\n",
    "obs = test_env.reset()\n",
    "while not terminated:\n",
    "    action, _states = model.predict(obs)\n",
    "    obs, reward, terminated, info = test_env.step(action)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4abc7e6a",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "We also save a working model you can use inside ./models/demo_models if you have difficulty training the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "421eb34e",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = DQN.load(\"./models/demo_model\", env=test_env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9fb12e66",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# cleanup\n",
    "test_env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a7cc646",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Summary\n",
    "Amazing! You can now implement a custom gymnasium environment using kdFlex. In doing so, you have also opened yourself up to all the machine learning libraries that interface with gymnasium. \n",
    "\n",
    "## Further Readings\n",
    "- [AtrvJr Reinforcement Learning](../example_atrvjr_learning/notebook.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "argv": [
    "python",
    "-m",
    "ipykernel_launcher",
    "-f",
    "{connection_file}"
   ],
   "display_name": "Python 3 (ipykernel)",
   "env": null,
   "interrupt_mode": "signal",
   "kernel_protocol_version": "5.5",
   "language": "python",
   "metadata": {
    "debugger": true
   },
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  },
  "name": "notebook.ipynb"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
