{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "defd5bb9",
   "metadata": {},
   "source": [
    "# ATRVjr Reinforcement Learning\n",
    "\n",
    "This tutorial extends [Cart Pole Reinforcement Learning](../example_cartpole/notebook.ipynb) by training a model to drive an ATRVjr vehicle to a target. This tutorial consists of 2 parts: building the environment and training a model on the environment. \n",
    "\n",
    "Requirements:\n",
    "- [Cart Pole Reinforcement Learning](../example_cartpole/notebook.ipynb)\n",
    "- [ATRVjr driving](../example_atrvjr_drive/notebook.ipynb)\n",
    "\n",
    "In this tutorial we will:\n",
    "- Implement [gymnasium.Env](https://gymnasium.farama.org/api/env/)\n",
    "  - Create the multibody\n",
    "  - Create the environment\n",
    "  - Initialize attributes\n",
    "  - Setup the simulation\n",
    "  - Setup helpers\n",
    "  - Reset the environment\n",
    "  - Step the environment\n",
    "- Train the model\n",
    "  - Load the environment\n",
    "  - Setup callbacks\n",
    "  - Train a SAC model\n",
    "  - Test model performance\n",
    "\n",
    "![](../resources/nb_images/atrv_learning.png)\n",
    "\n",
    "Scripts:\n",
    "- [env.py](./env.py)\n",
    "\n",
    "\n",
    "For more in-depth descriptions of **kdflex** concepts, see [usage](usage_page)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b26422ea",
   "metadata": {},
   "source": [
    "First let's verify that the reinforcement learning packages are installed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "96d2a326",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import stable_baselines3\n",
    "    import gymnasium\n",
    "except ModuleNotFoundError:\n",
    "    raise ModuleNotFoundError(\n",
    "        \"Please ensure that the following python packages are installed:\\n\"\n",
    "        \"    stable-baselines3 gymnasium\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ddabfc8",
   "metadata": {},
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "# Implement gymnasium.Env\n",
    "Because classes can't be split across notebook cells, we will show snippets of code with explanations. The full class is found at [env.py](./env.py).\n",
    "\n",
    "```python\n",
    "from typing import Optional, cast\n",
    "import numpy as np\n",
    "import gymnasium as gym\n",
    "from pathlib import Path\n",
    "from math import pi, sqrt\n",
    "\n",
    "from Karana.Frame import FrameContainer\n",
    "from Karana.Core import discard, allDestroyed\n",
    "from Karana.Dynamics.SOADyn_types import (\n",
    "    SubGraphDS,\n",
    "    PinSubhingeDS,\n",
    "    Linear3SubhingeDS,\n",
    "    SphericalSubhingeDS,\n",
    "    HingeDS,\n",
    "    BodyWithContextDS,\n",
    "    ScenePartSpecDS,\n",
    ")\n",
    "from Karana.Dynamics import (\n",
    "    Multibody,\n",
    "    HingeType,\n",
    "    StatePropagator,\n",
    "    TimedEvent,\n",
    "    PhysicalBody,\n",
    "    PhysicalHinge,\n",
    ")\n",
    "from Karana.Math import UnitQuaternion, HomTran\n",
    "from Karana.Scene import (\n",
    "    BoxGeometry,\n",
    "    CylinderGeometry,\n",
    "    Color,\n",
    "    PhysicalMaterialInfo,\n",
    "    PhysicalMaterial,\n",
    "    LAYER_GRAPHICS,\n",
    "    LAYER_ALL,\n",
    "    LAYER_COLLISION,\n",
    "    Texture,\n",
    ")\n",
    "from Karana.Scene import ProxyScenePart, ProxyScene\n",
    "from Karana.Scene import CoalScene\n",
    "from Karana.Collision import FrameCollider\n",
    "from Karana.Models import Gravity, UniformGravity, UpdateProxyScene, SyncRealTime, PenaltyContact, OutputUpdateType\n",
    "from Karana.Math import IntegratorType\n",
    "from Karana.KUtils.BasicPrefab import BasicPrefabDS\n",
    "\n",
    "\n",
    "# helper to wrap angles\n",
    "def wrap_to_pi(x):\n",
    "    return (x + np.pi) % (2 * np.pi) - np.pi\n",
    "```"
   ]
  },
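  {
   "cell_type": "markdown",
   "id": "7f3a9c01",
   "metadata": {},
   "source": [
    "The `wrap_to_pi` helper maps any angle into the interval [-π, π) by shifting by π, taking the remainder modulo 2π, and shifting back. A quick stand-alone check of its behavior (the helper is copied so the snippet runs on its own):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def wrap_to_pi(x):\n",
    "    # shift by pi, wrap into [0, 2*pi), then shift back into [-pi, pi)\n",
    "    return (x + np.pi) % (2 * np.pi) - np.pi\n",
    "\n",
    "\n",
    "angles = np.array([0.0, 3 * np.pi / 2, -3 * np.pi / 2, np.pi])\n",
    "wrapped = wrap_to_pi(angles)\n",
    "# 3*pi/2 becomes -pi/2, -3*pi/2 becomes pi/2, and pi maps to -pi\n",
    "print(wrapped)\n",
    "```\n",
    "\n",
    "Note that π itself maps to -π; both describe the same heading, which is why the interval is half-open."
   ]
  },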
  {
   "cell_type": "markdown",
   "id": "61df3ec3",
   "metadata": {},
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Create the Multibody\n",
    "We create and edit our multibody by importing a URDF file into a SOADyn_types.SubGraphDS. This is the same procedure as [ATRVjr driving](../example_atrvjr_drive/notebook.ipynb).\n",
    "\n",
    "\n",
    "```python\n",
    "def createMbody(fc: FrameContainer) -> Multibody:\n",
    "    urdf_file = Path().absolute().parent / \"resources\" / \"atrvjr\" / \"atrvjr.urdf\"\n",
    "    dark = BasicPrefabDS.fromFile(urdf_file).params.subtree\n",
    "    multibody_info = SubGraphDS(name=dark.name, base_bodies=dark.base_bodies)\n",
    "    # convert to full6DOF hinge between root frame and rover.\n",
    "    multibody_info.base_bodies[0].body.hinge = HingeDS(\n",
    "        hinge_type=HingeType.FULL6DOF,\n",
    "        subhinges=[\n",
    "            Linear3SubhingeDS(prescribed=False),\n",
    "            SphericalSubhingeDS(prescribed=False),\n",
    "        ],\n",
    "    )\n",
    "\n",
    "    # unlock locked Hinges for wheels from the URDF\n",
    "    for i in [0, 1]:\n",
    "        multibody_info.base_bodies[0].children[i].body.hinge = HingeDS(\n",
    "            hinge_type=HingeType.REVOLUTE,\n",
    "            subhinges=[PinSubhingeDS(prescribed=False, unit_axis=np.array([0.0, 0.0, 1.0]))],\n",
    "        )\n",
    "    for i in [2, 3]:\n",
    "        multibody_info.base_bodies[0].children[i].body.hinge = HingeDS(\n",
    "            hinge_type=HingeType.REVOLUTE,\n",
    "            subhinges=[PinSubhingeDS(prescribed=False, unit_axis=np.array([0.0, 0.0, -1.0]))],\n",
    "        )\n",
    "        \n",
    "    # Remove all collision meshes. We want to add these manually.\n",
    "    def removeCollision():\n",
    "        def _removeParts(parts: list[ScenePartSpecDS]):\n",
    "            idx = []\n",
    "            for k, p in enumerate(parts):\n",
    "                if p.layers == LAYER_COLLISION:\n",
    "                    idx.append(k)\n",
    "            idx.reverse()\n",
    "            for k in idx:\n",
    "                parts.pop(k)\n",
    "\n",
    "        def _recurse(b: BodyWithContextDS):\n",
    "            _removeParts(b.body.scene_parts)\n",
    "            for c in b.children:\n",
    "                _recurse(c)\n",
    "\n",
    "        for base in multibody_info.base_bodies:\n",
    "            _recurse(base)\n",
    "\n",
    "    removeCollision()\n",
    "\n",
    "    multibody = multibody_info.toMultibody(fc)\n",
    "    multibody.ensureHealthy()\n",
    "    multibody.resetData()\n",
    "    assert multibody.isReady()\n",
    "    return multibody\n",
    "```"
   ]
  },
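  {
   "cell_type": "markdown",
   "id": "7f3a9c02",
   "metadata": {},
   "source": [
    "`removeCollision` deletes entries while walking the body tree. It first collects the indices of matching parts and then pops them in reverse order, so removing one entry never shifts the indices that are still waiting to be removed. A stand-alone sketch of the same pattern, using hypothetical `Part` and `Node` classes in place of `ScenePartSpecDS` and `BodyWithContextDS`, and a hypothetical integer value for the collision layer:\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass, field\n",
    "\n",
    "# hypothetical stand-in for Karana.Scene.LAYER_COLLISION\n",
    "LAYER_COLLISION = 2\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Part:  # hypothetical stand-in for ScenePartSpecDS\n",
    "    name: str\n",
    "    layers: int\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Node:  # hypothetical stand-in for BodyWithContextDS\n",
    "    parts: list\n",
    "    children: list = field(default_factory=list)\n",
    "\n",
    "\n",
    "def remove_parts(parts, layer):\n",
    "    # collect matching indices first, then pop from the back so earlier indices stay valid\n",
    "    idx = [k for k, p in enumerate(parts) if p.layers == layer]\n",
    "    for k in reversed(idx):\n",
    "        parts.pop(k)\n",
    "\n",
    "\n",
    "def recurse(node, layer):\n",
    "    # strip this node's parts, then visit every child body\n",
    "    remove_parts(node.parts, layer)\n",
    "    for child in node.children:\n",
    "        recurse(child, layer)\n",
    "\n",
    "\n",
    "wheel = Node([Part('wheel_visual', 1), Part('wheel_collision', LAYER_COLLISION)])\n",
    "root = Node([Part('chassis_collision', LAYER_COLLISION), Part('chassis_visual', 1)], [wheel])\n",
    "recurse(root, LAYER_COLLISION)\n",
    "print([p.name for p in root.parts], [p.name for p in wheel.parts])\n",
    "```\n",
    "\n",
    "Because the comparison uses `==`, only parts whose layers value is exactly the collision layer are removed; a part whose layers value differs from `LAYER_COLLISION` (for example `LAYER_ALL`) would be kept."
   ]
  },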
  {
   "cell_type": "markdown",
   "id": "944be11b",
   "metadata": {},
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Create Environment\n",
    "Here, we attach cylinder geometries for the wheels. Additionally, we set up the ground and apply a custom grid texture from a .png file using {py:meth}`Karana.Scene.Texture.lookupOrCreateTexture`. We add walls around all four sides to detect when the car is going out of bounds by checking for collisions.\n",
    "\n",
    "```python\n",
    "def createEnvironment(\n",
    "    size: int, mb: Multibody, proxy_scene: ProxyScene\n",
    ") -> tuple[PhysicalBody, PhysicalBody, PhysicalBody, PhysicalBody]:\n",
    "    thickness = 0.25\n",
    "    height = 2.5\n",
    "    mat_info = PhysicalMaterialInfo()\n",
    "    mat_info.color = Color.GRAY\n",
    "    gray = PhysicalMaterial(mat_info)\n",
    "    # add wheels to our atrv model\n",
    "    wheel_geom = CylinderGeometry(0.125, 0.075)\n",
    "    left_front_wheel = mb.getBody(\"left_front_wheel_link\")\n",
    "    left_front_wheel_node = ProxyScenePart(\n",
    "        name=\"left_front_wheel_link_node\", scene=proxy_scene, geometry=wheel_geom, material=gray\n",
    "    )\n",
    "    left_front_wheel_node.attachTo(left_front_wheel)\n",
    "    left_front_wheel_node.setTranslation([0, 0, 0.03])\n",
    "    left_front_wheel_node.setUnitQuaternion(UnitQuaternion(pi / 2, [1.0, 0.0, 0.0]))\n",
    "\n",
    "    left_rear_wheel = mb.getBody(\"left_rear_wheel_link\")\n",
    "    left_rear_wheel_node = ProxyScenePart(\n",
    "        name=\"left_rear_wheel_link_node\", scene=proxy_scene, geometry=wheel_geom, material=gray\n",
    "    )\n",
    "    left_rear_wheel_node.attachTo(left_rear_wheel)\n",
    "    left_rear_wheel_node.setTranslation([0, 0, 0.03])\n",
    "    left_rear_wheel_node.setUnitQuaternion(UnitQuaternion(pi / 2, [1.0, 0.0, 0.0]))\n",
    "\n",
    "    right_front_wheel = mb.getBody(\"right_front_wheel_link\")\n",
    "    right_front_wheel_node = ProxyScenePart(\n",
    "        name=\"right_front_wheel_link_node\", scene=proxy_scene, geometry=wheel_geom, material=gray\n",
    "    )\n",
    "    right_front_wheel_node.attachTo(right_front_wheel)\n",
    "    right_front_wheel_node.setTranslation([0, 0, 0.03])\n",
    "    right_front_wheel_node.setUnitQuaternion(UnitQuaternion(pi / 2, [1.0, 0.0, 0.0]))\n",
    "\n",
    "    right_rear_wheel = mb.getBody(\"right_rear_wheel_link\")\n",
    "    right_rear_wheel_node = ProxyScenePart(\n",
    "        name=\"right_rear_wheel_link_node\", scene=proxy_scene, geometry=wheel_geom, material=gray\n",
    "    )\n",
    "    right_rear_wheel_node.attachTo(right_rear_wheel)\n",
    "    right_rear_wheel_node.setTranslation([0, 0, 0.03])\n",
    "    right_rear_wheel_node.setUnitQuaternion(UnitQuaternion(pi / 2, [1.0, 0.0, 0.0]))\n",
    "\n",
    "    # add ground\n",
    "    ground_geom = BoxGeometry(size, size, thickness)\n",
    "    mat_info.color = Color.WHITE\n",
    "    texture_file = Path().absolute().parent / \"resources\" / \"atrvjr\" / \"grid.png\"\n",
    "    grid_texture = Texture.lookupOrCreateTexture(str(texture_file))\n",
    "    mat_info.color_map = grid_texture\n",
    "    grid = PhysicalMaterial(mat_info)\n",
    "    ground = ProxyScenePart(\n",
    "        \"ground\", scene=proxy_scene, geometry=ground_geom, material=grid, layers=LAYER_ALL\n",
    "    )\n",
    "    ground.setTranslation([0, 0, -0.1])\n",
    "\n",
    "    # add walls\n",
    "    mat_info.color = Color.KHAKI\n",
    "    mat_info.color_map = None\n",
    "    khaki = PhysicalMaterial(mat_info)\n",
    "    wall_geom = BoxGeometry(size, thickness, height)\n",
    "    north_wall = ProxyScenePart(\n",
    "        \"north_wall\", scene=proxy_scene, geometry=wall_geom, material=khaki, layers=LAYER_ALL\n",
    "    )\n",
    "    north_wall.setTranslation([size / 2.0 + thickness / 2, 0, height / 2 - thickness])\n",
    "    north_wall.setUnitQuaternion(UnitQuaternion([1 / sqrt(2), 1 / sqrt(2), 0, 0]))\n",
    "    south_wall = ProxyScenePart(\n",
    "        \"south_wall\", scene=proxy_scene, geometry=wall_geom, material=khaki, layers=LAYER_ALL\n",
    "    )\n",
    "    south_wall.setTranslation([-size / 2.0 - thickness / 2, 0, height / 2 - thickness])\n",
    "    south_wall.setUnitQuaternion(UnitQuaternion([-1 / sqrt(2), 1 / sqrt(2), 0, 0]))\n",
    "    west_wall = ProxyScenePart(\n",
    "        \"west_wall\", scene=proxy_scene, geometry=wall_geom, material=khaki, layers=LAYER_ALL\n",
    "    )\n",
    "    west_wall.setTranslation([0, size / 2.0 + thickness / 2, height / 2 - thickness])\n",
    "    east_wall = ProxyScenePart(\n",
    "        \"east_wall\", scene=proxy_scene, geometry=wall_geom, material=khaki, layers=LAYER_ALL\n",
    "    )\n",
    "    east_wall.setTranslation([0, -size / 2.0 - thickness / 2, height / 2 - thickness])\n",
    "\n",
    "    return left_front_wheel, left_rear_wheel, right_front_wheel, right_rear_wheel\n",
    "```"
   ]
  },
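  {
   "cell_type": "markdown",
   "id": "7f3a9c03",
   "metadata": {},
   "source": [
    "The wall placement is simple arithmetic on the arena dimensions: each wall is centered half a wall thickness outside the ground edge at ±size/2, and every wall shares the same height offset. The sketch below (a hypothetical helper, not part of env.py) reproduces the translations passed to `setTranslation` and also checks the margin between the spawn bound used later in `ATRVEnv` (`size / 2 - 3`) and the inner faces of the walls:\n",
    "\n",
    "```python\n",
    "def wall_centers(size, thickness, height):\n",
    "    # each wall sits half a thickness outside the ground edge at +/- size/2\n",
    "    offset = size / 2.0 + thickness / 2\n",
    "    # shared z-translation used for every wall in createEnvironment\n",
    "    z = height / 2 - thickness\n",
    "    return {\n",
    "        'north': (offset, 0.0, z),\n",
    "        'south': (-offset, 0.0, z),\n",
    "        'west': (0.0, offset, z),\n",
    "        'east': (0.0, -offset, z),\n",
    "    }\n",
    "\n",
    "\n",
    "# values used by ATRVEnv: size = 30, thickness = 0.25, height = 2.5\n",
    "size, thickness, height = 30, 0.25, 2.5\n",
    "centers = wall_centers(size, thickness, height)\n",
    "print(centers['north'])  # (15.125, 0.0, 1.0)\n",
    "\n",
    "# targets and start positions are sampled within +/- bound\n",
    "bound = size / 2.0 - 3.0\n",
    "inner_face = centers['north'][0] - thickness / 2\n",
    "print(inner_face - bound)  # 3.0 units of clearance\n",
    "```"
   ]
  },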
  {
   "cell_type": "markdown",
   "id": "e86473f3",
   "metadata": {},
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Initialize Attributes\n",
    "Similar to [Cart Pole](../example_cartpole/notebook.ipynb), it is necessary to implement a few key attributes. Here is the breakdown for this scenario:\n",
    "- **observation space:** The agent observes the target position, car position, car linear velocities, car roll, car yaw rate, and the sine and cosine of the heading error (the angle between the car's heading and the direction to the target).\n",
    "  - roll is included so the agent can detect when it turns too sharply and starts to flip over\n",
    "  - angles are commonly encoded as their sine and cosine to keep the values continuous: -3.14 and 3.14 radians describe almost the same heading, but as raw numbers they are far apart, which can destabilize training\n",
    "  - the minimum and maximum velocities were determined empirically\n",
    "- **action space:** The agent selects 2 continuous values between -1 and 1, which are scaled by `self._torque_magnitude` and applied as torques to the left-side and right-side wheels.\n",
    "- **np_random** is a random number generator that can be seeded for reproducibility.\n",
    "- **render_mode** determines whether graphics are shown. `\"none\"` is faster but shows no graphics, while `\"human\"` renders the web scene.\n",
    "\n",
    "```python\n",
    "class ATRVEnv(gym.Env):\n",
    "    metadata = {\"render_modes\": [\"human\", \"none\"]}\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        render_mode: Optional[str] = None,\n",
    "        port: int = 0,\n",
    "        sync_real_time: bool = False,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        self.render_mode = \"none\"\n",
    "        if render_mode and render_mode in self.metadata[\"render_modes\"]:\n",
    "            self.render_mode = render_mode\n",
    "        self.size = 30\n",
    "        self._bound = self.size / 2.0 - 3.0\n",
    "        self._torque_magnitude = 20.0\n",
    "\n",
    "        # initialize the simulation\n",
    "        self.init_sim(port, sync_real_time)\n",
    "\n",
    "        # [txpos, typos, x, y, vx, vy, roll, yaw_vel, sin(angle_diff), cos(angle_diff)]\n",
    "        low = np.array(\n",
    "            [\n",
    "                -self._bound,\n",
    "                -self._bound,\n",
    "                -self.size,\n",
    "                -self.size,\n",
    "                -6.0,\n",
    "                -6.0,\n",
    "                -0.25,\n",
    "                -6.0,\n",
    "                -1.0,\n",
    "                -1.0,\n",
    "            ],\n",
    "            dtype=np.float32,\n",
    "        )\n",
    "        high = np.array(\n",
    "            [self._bound, self._bound, self.size, self.size, 6.0, 6.0, 0.25, 6.0, 1.0, 1.0],\n",
    "            dtype=np.float32,\n",
    "        )\n",
    "        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)\n",
    "\n",
    "        # define action space\n",
    "        # torque on [left, right]\n",
    "        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)\n",
    "\n",
    "        # initialize class variables\n",
    "        self._target = None\n",
    "        self._left = None\n",
    "        self._right = None\n",
    "        self._target_pos = None\n",
    "        self._agent_pos = None\n",
    "        self._agent_vel = None\n",
    "        self._agent_roll = None\n",
    "        self._agent_yaw_vel = None\n",
    "        self._angle_diff = None\n",
    "        self.np_random = None\n",
    "```"
   ]
  },
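  {
   "cell_type": "markdown",
   "id": "7f3a9c04",
   "metadata": {},
   "source": [
    "To see why the heading error is passed as a sine and cosine rather than as a raw angle, compare two nearly identical headings on opposite sides of ±π. A small NumPy sketch:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "a, b = -3.14, 3.14  # almost the same heading, on opposite sides of +/- pi\n",
    "\n",
    "# as raw numbers the two angles look far apart\n",
    "raw_gap = abs(a - b)\n",
    "\n",
    "# the (sin, cos) encoding places them next to each other on the unit circle\n",
    "enc_a = np.array([np.sin(a), np.cos(a)])\n",
    "enc_b = np.array([np.sin(b), np.cos(b)])\n",
    "enc_gap = np.linalg.norm(enc_a - enc_b)\n",
    "\n",
    "# the angle is still recoverable from the encoding with arctan2(sin, cos)\n",
    "recovered = np.arctan2(enc_b[0], enc_b[1])\n",
    "print(raw_gap, enc_gap, recovered)\n",
    "```\n",
    "\n",
    "The raw values differ by about 6.28, while the encoded vectors differ by roughly 0.003, so a small change in heading always produces a small change in the observation."
   ]
  },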
  {
   "cell_type": "markdown",
   "id": "e7ec5770",
   "metadata": {},
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Setup the Simulation\n",
    "Here the environment follows the same simulation setup as the other notebooks. This time, the {py:class}`Karana.Frame.ChainedFrameToFrame` and the 6 degree-of-freedom hinge between the origin and the multibody are also stored.\n",
    "- {py:class}`Karana.Frame.ChainedFrameToFrame` (self._f2f) can be queried for the relative pose, spatial velocity, and spatial acceleration. Because it spans from the world origin (the multibody's virtual root) to the ATRVjr's `center_link` body, it returns the rover's state in world coordinates.\n",
    "- {py:class}`Karana.Dynamics.PhysicalHinge` (self._root_hinge) is a full 6-DOF hinge (3 translational and 3 rotational degrees of freedom) between the world origin and the ATRVjr multibody. Setting its linear and angular subhinge values sets the pose, velocity, and acceleration of the multibody.\n",
    "\n",
    "The StatePropagator callback functions implement applying torques and checking for termination.\n",
    "- `check_termination(t, x)` will terminate the episode if the car has flipped over or collided with the walls.\n",
    "- `pre_deriv_fn(t, x)` applies torques to move the wheels.\n",
    "\n",
    "```python\n",
    "    def init_sim(self, port: int = 0, sync_real_time: bool = False):\n",
    "        # setup bodies and scene\n",
    "        self._fc = FrameContainer(\"root\")\n",
    "        self._mb = createMbody(self._fc)\n",
    "\n",
    "        # frame-to-frame from the virtual root to the rover body, used to query its world pose and velocity\n",
    "        self._f2f = cast(PhysicalBody, self._mb.virtualRoot()).frameToFrame(\n",
    "            self._mb.getBody(\"center_link\")\n",
    "        )\n",
    "        self._root_hinge = cast(PhysicalHinge, self._mb.getBody(\"center_link\").parentHinge())\n",
    "\n",
    "        # setup graphics if enabled\n",
    "        if self.render_mode == \"human\":\n",
    "            self._cleanup_graphics, viz_scene = self._mb.setupGraphics(axes=0.5, port=port)\n",
    "            viz_scene.defaultCamera().pointCameraAt(\n",
    "                [-0.8 * self.size, 0.0, self.size], [0, 0, 0], [0, 0, 1]\n",
    "            )\n",
    "            self._proxy_scene = self._mb.getScene()\n",
    "        else:\n",
    "            self._proxy_scene = ProxyScene(f\"{self._mb.name()}_proxyscene\", self._mb.virtualRoot())\n",
    "\n",
    "        # setup wheels, ground, and walls\n",
    "        left_front_wheel, left_rear_wheel, right_front_wheel, right_rear_wheel = createEnvironment(\n",
    "            self.size, self._mb, self._proxy_scene\n",
    "        )\n",
    "\n",
    "        # setup collision client\n",
    "        col_scene = CoalScene(\"collision_scene\")\n",
    "        self._proxy_scene.registerClientScene(\n",
    "            col_scene, self._mb.virtualRoot(), layers=LAYER_COLLISION\n",
    "        )\n",
    "\n",
    "        # setup state propagator and models\n",
    "        self._sp = StatePropagator(self._mb, integrator_type=IntegratorType.RK4)\n",
    "\n",
    "        # gravity model\n",
    "        ug = Gravity(\"grav_model\", self._sp, UniformGravity(\"uniform_gravity\"), self._mb)\n",
    "        ug.getGravityInterface().setGravity(np.array([0, 0, -9.81]), 0.0, OutputUpdateType.PRE_HOP)\n",
    "        del ug\n",
    "\n",
    "        # for visualization\n",
    "        if sync_real_time:\n",
    "            SyncRealTime(\"sync_real_time\", self._sp, 1.0)\n",
    "\n",
    "        # update proxy scene model\n",
    "        UpdateProxyScene(\"update_proxy_scene\", self._sp, self._proxy_scene)\n",
    "\n",
    "        # collision model\n",
    "        hc = HuntCrossley(\"hunt_crossley_contact\")\n",
    "        hc.params.kp = 100000\n",
    "        hc.params.kc = 20000\n",
    "        hc.params.mu = 0.3\n",
    "        hc.params.n = 1.5\n",
    "        hc.params.linear_region_tol = 1e-3\n",
    "        PenaltyContact(\"penalty_contact\", self._sp, self._mb, [FrameCollider(self._proxy_scene, col_scene)], hc)\n",
    "        del hc\n",
    "\n",
    "        def check_termination(t, x):\n",
    "            # check if the car is flipping over\n",
    "            flip = abs(self._f2f.relTransform().getUnitQuaternion().toEulerAngles().alpha()) > 0.2\n",
    "            # if more than our 4 wheels are in contact, we have a collision\n",
    "            if len(col_scene.cachedCollisions()) > 4 or flip:\n",
    "                self._terminated = True\n",
    "                return True\n",
    "            return False\n",
    "\n",
    "        def pre_deriv_fn(t, x):\n",
    "            left_front_wheel.parentHinge().subhinge(0).setT(self._left)\n",
    "            left_rear_wheel.parentHinge().subhinge(0).setT(self._left)\n",
    "            right_front_wheel.parentHinge().subhinge(0).setT(self._right)\n",
    "            right_rear_wheel.parentHinge().subhinge(0).setT(self._right)\n",
    "\n",
    "        self._sp.fns.pre_deriv_fns[\"apply_torque\"] = pre_deriv_fn\n",
    "        self._sp.fns.terminate_advance_to_fns[\"check_termination\"] = check_termination\n",
    "```"
   ]
  },
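  {
   "cell_type": "markdown",
   "id": "7f3a9c05",
   "metadata": {},
   "source": [
    "The rule inside `check_termination` depends on only two quantities: the roll angle and the number of cached collisions. Factoring it into a pure function (a hypothetical helper with made-up argument names, not part of env.py) makes the thresholds easy to test without running the simulation. It assumes, as the comment in `check_termination` implies, that each wheel touching the ground accounts for one cached collision:\n",
    "\n",
    "```python\n",
    "def should_terminate(roll, n_collisions, roll_limit=0.2, n_wheels=4):\n",
    "    # flipping: the absolute roll angle exceeds the threshold used in check_termination\n",
    "    flip = abs(roll) > roll_limit\n",
    "    # the wheel-ground contacts are expected; any extra contact means the car hit a wall\n",
    "    return n_collisions > n_wheels or flip\n",
    "\n",
    "\n",
    "print(should_terminate(0.05, 4))  # False: driving normally\n",
    "print(should_terminate(0.05, 5))  # True: an extra contact, e.g. a wall\n",
    "print(should_terminate(-0.3, 4))  # True: rolling over\n",
    "```"
   ]
  },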
  {
   "cell_type": "markdown",
   "id": "043c47d3",
   "metadata": {},
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Setup Helpers\n",
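    "These are a few helper methods for gymnasium. `_get_obs` concatenates the observation values into a single float32 NumPy array, in the order declared for the observation space, for the model to train on. `_get_info` returns the same quantities as a dictionary and is primarily for debugging. `close` releases the simulation objects and graphics when the environment is shut down.\n",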
    "\n",
    "```python\n",
    "    def _get_obs(self):\n",
    "        return np.concatenate(\n",
    "            (\n",
    "                self._target_pos,\n",
    "                self._agent_pos,\n",
    "                self._agent_vel,\n",
    "                self._agent_roll,\n",
    "                self._agent_yaw_vel,\n",
    "                np.sin(self._angle_diff),\n",
    "                np.cos(self._angle_diff),\n",
    "            )\n",
    "        ).astype(np.float32)\n",
    "\n",
    "    # use non-normalized values for info\n",
    "    def _get_info(self):\n",
    "        return {\n",
    "            \"target_pos\": self._target_pos,\n",
    "            \"agent_pos\": self._agent_pos,\n",
    "            \"agent_vel\": self._agent_vel,\n",
    "            \"agent_roll\": self._agent_roll,\n",
    "            \"agent_yaw_vel\": self._agent_yaw_vel,\n",
    "            \"angle_diff\": self._angle_diff,\n",
    "        }\n",
    "\n",
    "    def close(self):\n",
    "        del self._proxy_scene, self._f2f, self._root_hinge, self._target\n",
    "        discard(self._sp)\n",
    "        if self.render_mode == \"human\":\n",
    "            self._cleanup_graphics()\n",
    "        discard(self._mb)\n",
    "        discard(self._fc)\n",
    "        assert allDestroyed()\n",
    "```"
   ]
  },
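  {
   "cell_type": "markdown",
   "id": "7f3a9c06",
   "metadata": {},
   "source": [
    "`_get_obs` relies on every component being a 1-D array: `np.concatenate` cannot join zero-dimensional arrays, which is presumably why `reset` and `step` store roll and yaw rate as length-1 arrays rather than scalars. A sketch with made-up values that shows the resulting 10-element layout:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# made-up values, shaped like the attributes set in reset() and step()\n",
    "target_pos = np.array([3.0, 4.0], dtype=np.float32)\n",
    "agent_pos = np.array([0.0, 0.0], dtype=np.float32)\n",
    "agent_vel = np.zeros(2, dtype=np.float32)\n",
    "agent_roll = np.zeros(1, dtype=np.float32)\n",
    "agent_yaw_vel = np.zeros(1, dtype=np.float32)\n",
    "angle_diff = np.array([0.5], dtype=np.float32)\n",
    "\n",
    "obs = np.concatenate(\n",
    "    (target_pos, agent_pos, agent_vel, agent_roll, agent_yaw_vel,\n",
    "     np.sin(angle_diff), np.cos(angle_diff))\n",
    ").astype(np.float32)\n",
    "\n",
    "# [txpos, typos, x, y, vx, vy, roll, yaw_vel, sin(angle_diff), cos(angle_diff)]\n",
    "print(obs.shape)  # (10,)\n",
    "```"
   ]
  },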
  {
   "cell_type": "markdown",
   "id": "498b965c",
   "metadata": {},
   "source": [
    "## Reset the Environment\n",
    "As explained in Cart Pole, this method sets up a new scenario at the start of every episode. Here, the stored attributes are reset, and the multibody is placed at a random position and yaw orientation using its root hinge. The blue target is also moved to a random position on the board.\n",
    "\n",
    "```python\n",
    "    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):\n",
    "        super().reset(seed=seed)\n",
    "        if self.np_random is None:\n",
    "            self.np_random, _ = gym.utils.seeding.np_random(seed)\n",
    "\n",
    "        # initialize values\n",
    "        self._terminated = False\n",
    "        self._left = 0.0\n",
    "        self._right = 0.0\n",
    "        self._target_pos = self.np_random.uniform(-self._bound, self._bound, size=(2,)).astype(\n",
    "            np.float32\n",
    "        )\n",
    "        self._agent_pos = self.np_random.uniform(-self._bound, self._bound, size=(2,)).astype(\n",
    "            np.float32\n",
    "        )\n",
    "        self._agent_vel = np.zeros(2, dtype=np.float32)\n",
    "        self._agent_roll = np.zeros(1, dtype=np.float32)\n",
    "        self._agent_yaw_vel = np.zeros(1, dtype=np.float32)\n",
    "        agent_yaw = self.np_random.uniform(-np.pi, np.pi, size=(1,)).astype(np.float32)\n",
    "\n",
    "        # initial distance and heading error; step() uses them as the previous-step baseline\n",
    "        difference = self._target_pos - self._agent_pos\n",
    "        ideal_angle = np.arctan2(difference[1], difference[0])\n",
    "        self._distance = np.linalg.norm(difference)\n",
    "        self._angle_diff = wrap_to_pi(ideal_angle - agent_yaw)\n",
    "\n",
    "        # reset the rover position and velocity\n",
    "        self._mb.resetData()\n",
    "        orientation = UnitQuaternion(agent_yaw[0], [0.0, 0.0, 1.0])\n",
    "        self._root_hinge.fitQ(\n",
    "            HomTran(q=orientation, vec=[self._agent_pos[0], self._agent_pos[1], 0])\n",
    "        )\n",
    "\n",
    "        # add the target\n",
    "        if not self._target:\n",
    "            mat_info = PhysicalMaterialInfo()\n",
    "            mat_info.color = Color.BLUE\n",
    "            blue = PhysicalMaterial(mat_info)\n",
    "\n",
    "            # create the target\n",
    "            target_geom = CylinderGeometry(0.25, 0.05)\n",
    "            self._target = ProxyScenePart(\n",
    "                \"target_part\",\n",
    "                scene=self._proxy_scene,\n",
    "                geometry=target_geom,\n",
    "                material=blue,\n",
    "                layers=LAYER_GRAPHICS,\n",
    "            )\n",
    "        self._target.setTranslation([self._target_pos[0], self._target_pos[1], 0.25])\n",
    "        self._target.setUnitQuaternion(UnitQuaternion([1 / sqrt(2), 0, 0, 1 / sqrt(2)]))\n",
    "\n",
    "        # simulation\n",
    "        t_init = np.timedelta64(0, \"ns\")\n",
    "        self._sp.setTime(t_init)\n",
    "        self._sp.setState(self._sp.assembleState())\n",
    "\n",
    "        # setup timed event (every 0.01s)\n",
    "        h = np.timedelta64(int(1e7), \"ns\")\n",
    "        t = TimedEvent(\"hop_size\", h, lambda _: None, False)\n",
    "        t.period = h\n",
    "        self._sp.registerTimedEvent(t)\n",
    "        del t\n",
    "        self._proxy_scene.update()\n",
    "        return self._get_obs(), self._get_info()\n",
    "```"
   ]
  },
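  {
   "cell_type": "markdown",
   "id": "7f3a9c07",
   "metadata": {},
   "source": [
    "The initial distance and heading error computed in `reset` are plain planar geometry. For example, with the target at (3, 4), the rover at the origin, and a yaw of 0, the distance is 5 and the heading error is atan2(4, 3) ≈ 0.927 rad. A stand-alone sketch of the same computation with these illustrative values:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def wrap_to_pi(x):\n",
    "    return (x + np.pi) % (2 * np.pi) - np.pi\n",
    "\n",
    "\n",
    "target_pos = np.array([3.0, 4.0])\n",
    "agent_pos = np.array([0.0, 0.0])\n",
    "agent_yaw = np.array([0.0])\n",
    "\n",
    "difference = target_pos - agent_pos\n",
    "# heading that would point the rover straight at the target\n",
    "ideal_angle = np.arctan2(difference[1], difference[0])\n",
    "distance = np.linalg.norm(difference)\n",
    "# signed heading error in [-pi, pi), shape (1,) like self._angle_diff\n",
    "angle_diff = wrap_to_pi(ideal_angle - agent_yaw)\n",
    "print(distance, angle_diff)\n",
    "```\n",
    "\n",
    "`step` later compares against these values, so they act as the baseline for the first step's progress terms."
   ]
  },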
  {
   "cell_type": "markdown",
   "id": "ba7d452f",
   "metadata": {},
   "source": [
    "<style>\n",
    "  code {\n",
    "    font-size: 16px; /* Adjust the size as needed */\n",
    "  }\n",
    "</style>\n",
    "\n",
    "## Step the Environment\n",
    "After every model decision, the simulation is advanced by 0.05 s and the current multibody pose and velocities are queried using the frame-to-frame (self._f2f). These values become part of the observation.\n",
    "\n",
    "The reward function gives 4.0 for reaching the target (a distance of less than 1.0) and -4.0 for any other termination (flipping over or colliding with a wall). If neither terminating condition is met, the reward combines these terms:\n",
    "- **progress:** how much closer the car has gotten to the target since the last step\n",
    "- **angle_progress:** how much the absolute angle between the car's heading and the direction to the target has decreased since the last step, weighted by 4\n",
    "- **step penalty:** a constant 0.03 subtracted every step to encourage the car to reach the target as quickly as possible\n",
    "\n",
    "```python\n",
    "    def step(self, action: np.ndarray):\n",
    "        torque = action * self._torque_magnitude\n",
    "        self._left = torque[0]\n",
    "        self._right = torque[1]\n",
    "        self._sp.advanceBy(0.05)\n",
    "        T = self._f2f.relTransform()\n",
    "        self._agent_pos = T.getTranslation()[:2]\n",
    "        q = T.getUnitQuaternion()\n",
    "        ea = q.toEulerAngles()\n",
    "        agent_yaw = np.array([ea.gamma()])\n",
    "        self._agent_roll = np.array([ea.alpha()])\n",
    "        sp_vel = self._f2f.relSpVel().toVector6()\n",
    "        self._agent_vel = sp_vel[3:5]\n",
    "        self._agent_yaw_vel = np.array([sp_vel[2]])\n",
    "\n",
    "        # check termination\n",
    "        difference = self._target_pos - self._agent_pos\n",
    "        distance = np.linalg.norm(difference)\n",
    "        ideal_angle = np.arctan2(difference[1], difference[0])\n",
    "        angle_diff = wrap_to_pi(ideal_angle - agent_yaw)\n",
    "        reward = 0.0\n",
    "        if bool(distance < 1.0):\n",
    "            self._terminated = True\n",
    "            reward += 4.0\n",
    "        elif self._terminated:\n",
    "            reward -= 4.0\n",
    "        else:\n",
    "            # reward calculations\n",
    "            progress = self._distance - distance\n",
    "            angle_progress = np.abs(self._angle_diff) - np.abs(angle_diff)\n",
    "            reward += progress\n",
    "            reward += 4 * angle_progress\n",
    "            reward -= 0.03  # step penalty\n",
    "            self._distance = distance\n",
    "            self._angle_diff = angle_diff\n",
    "\n",
    "        # observation, reward, terminated, truncated, data\n",
    "        return self._get_obs(), float(reward), self._terminated, False, self._get_info()\n",
    "```"
   ]
  },
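  {
   "cell_type": "markdown",
   "id": "7f3a9c08",
   "metadata": {},
   "source": [
    "The reward logic in `step` can be factored into a pure function, which makes its sign conventions easy to check. A sketch with a hypothetical name and signature; the constants (the 1.0 success radius, the ±4.0 terminal rewards, the weight of 4 on angle progress, and the 0.03 step penalty) are copied from `step`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_reward(prev_distance, distance, prev_angle_diff, angle_diff, terminated):\n",
    "    if distance < 1.0:\n",
    "        return 4.0  # reached the target\n",
    "    if terminated:\n",
    "        return -4.0  # flipped over or hit a wall\n",
    "    # positive when the car got closer to the target\n",
    "    progress = prev_distance - distance\n",
    "    # positive when the car now points more directly at the target\n",
    "    angle_progress = np.abs(prev_angle_diff) - np.abs(angle_diff)\n",
    "    return float(progress + 4 * angle_progress - 0.03)\n",
    "\n",
    "\n",
    "# 0.1 closer and 0.1 rad better aligned: 0.1 + 4 * 0.1 - 0.03\n",
    "print(compute_reward(5.0, 4.9, 0.5, 0.4, False))  # approximately 0.47\n",
    "```\n",
    "\n",
    "With the weight of 4, each radian of heading improvement is worth four times as much as each unit of distance closed, and standing still costs 0.03 per step."
   ]
  },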
  {
   "cell_type": "markdown",
   "id": "4da3cebd",
   "metadata": {},
   "source": [
    "# Train the model\n",
    "\n",
    "Awesome, now that the environment is set up, training can begin. This tutorial also uses [Stable Baselines3](https://stable-baselines3.readthedocs.io/en/master/), so make sure these dependencies are installed:\n",
    "```bash\n",
    "pip install stable-baselines3\n",
    "pip install gymnasium\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6766058",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from env import ATRVEnv\n",
    "import gymnasium as gym\n",
    "from stable_baselines3 import SAC\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize\n",
    "\n",
    "gym.register(\n",
    "    id=\"kdFlex-ATRVjr\",\n",
    "    entry_point=ATRVEnv,\n",
    "    max_episode_steps=200,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46acf012",
   "metadata": {},
   "source": [
    "## Load the Environment\n",
    "\n",
    "Here, we create the training environment from the ID registered above and wrap it in DummyVecEnv so that it is compatible with Stable Baselines3. DummyVecEnv takes a list of functions that each create an environment, so several environments can be batched into a single vectorized environment. DummyVecEnv steps them sequentially in a single process; [SubprocVecEnv](https://stable-baselines.readthedocs.io/en/master/guide/vec_envs.html#subprocvecenv) can instead run them in parallel across separate processes.\n",
    "\n",
    "We also use 2 other wrappers:\n",
    "- [Monitor](https://stable-baselines3.readthedocs.io/en/master/common/monitor.html) records episode statistics (such as reward and length) during training\n",
    "- [VecNormalize](https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html) maintains running statistics to normalize observations and rewards, because training can be unstable when values are large or vary widely."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fc99e74",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.monitor import Monitor\n",
    "\n",
    "env = Monitor(gym.make(\"kdFlex-ATRVjr\"))\n",
    "train_env = DummyVecEnv([lambda: env])\n",
    "train_env = VecNormalize(train_env)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a84134e4",
   "metadata": {},
   "source": [
    "## Setup Callbacks\n",
    "\n",
    "A common technique is to run a separate evaluation environment periodically during training and to stop training once the model has not improved for a specified number of evaluations, which helps prevent overfitting. In Stable Baselines3 this is done with an EvalCallback, optionally combined with a StopTrainingOnNoModelImprovement callback.\n",
    "- When given `best_model_save_path=\"./models/\"`, EvalCallback also saves the best model to ./models/best_model.zip"
   ]
  },
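  {
   "cell_type": "markdown",
   "id": "a6c1f3e8",
   "metadata": {},
   "source": [
    "In Stable Baselines3 this pattern is typically built from two pieces. `EvalCallback` evaluates the model every `eval_freq` steps and saves the best one to `best_model_save_path`. `StopTrainingOnNoModelImprovement` is passed to it through the `callback_after_eval` argument. The plain-Python sketch below shows only the stopping rule. Given a list of mean evaluation rewards, it returns the index of the evaluation after which training would stop. The function `stop_index`, its `patience` parameter, and the reward values are hypothetical. SB3's callback exposes comparable `max_no_improvement_evals` and `min_evals` parameters, and its exact counting may differ.\n",
    "\n",
    "```python\n",
    "def stop_index(eval_rewards, patience):\n",
    "    # return the index of the evaluation after which training stops, or None\n",
    "    best = float('-inf')\n",
    "    evals_without_improvement = 0\n",
    "    for i, reward in enumerate(eval_rewards):\n",
    "        if reward > best:\n",
    "            # new best model: an EvalCallback would save it at this point\n",
    "            best = reward\n",
    "            evals_without_improvement = 0\n",
    "        else:\n",
    "            evals_without_improvement += 1\n",
    "            if evals_without_improvement >= patience:\n",
    "                return i\n",
    "    return None\n",
    "\n",
    "\n",
    "# made-up mean rewards from seven periodic evaluations\n",
    "rewards = [-120.0, -80.0, -45.0, -50.0, -47.0, -46.0, -60.0]\n",
    "print(stop_index(rewards, patience=3))  # prints 5\n",
    "```\n",
    "\n",
    "The best model is the one from evaluation 2 (reward -45.0). The three evaluations after it fail to beat it, so training stops after evaluation 5."
   ]
  },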
  {
   "cell_type": "markdown",
   "id": "813f556d",
   "metadata": {},
   "source": [
    "## Train a SAC Model\n",
    "\n",
    "Because the action space is continuous, we use a Soft Actor-Critic (SAC) model with the default hyperparameters."
   ]
  },
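  {
   "cell_type": "markdown",
   "id": "c4b8d2f7",
   "metadata": {},
   "source": [
    "For reference, SAC is usually described as maximizing an entropy-regularized objective. This is the expected sum of rewards plus an entropy bonus on the policy, weighted by a temperature $\\alpha$. The entropy term keeps the policy stochastic and encourages exploration of the continuous action space:\n",
    "\n",
    "$$J(\\pi) = \\sum_{t} \\mathbb{E}_{(s_t, a_t) \\sim \\rho_\\pi}\\left[ r(s_t, a_t) + \\alpha\\, \\mathcal{H}\\big(\\pi(\\cdot \\mid s_t)\\big) \\right], \\qquad \\mathcal{H}\\big(\\pi(\\cdot \\mid s)\\big) = -\\mathbb{E}_{a \\sim \\pi(\\cdot \\mid s)}\\left[\\log \\pi(a \\mid s)\\right]$$\n",
    "\n",
    "Here $\\rho_\\pi$ denotes the state-action distribution induced by the policy $\\pi$. With the default hyperparameters used here, Stable Baselines3 tunes $\\alpha$ automatically (`ent_coef='auto'`)."
   ]
  },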
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "575f221c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SAC(\n",
    "    \"MlpPolicy\",\n",
    "    train_env,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b59fcc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# check if this is running in a test environment (DTEST_RUNNING is set)\n",
    "if os.getenv(\"DTEST_RUNNING\"):\n",
    "    # in a testing environment, run for less\n",
    "    model.learn(total_timesteps=1000)\n",
    "else:\n",
    "    # takes around 30 minutes\n",
    "    model.learn(total_timesteps=500000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18e1669d",
   "metadata": {},
   "source": [
    "## Test the Model\n",
    "Because the model was trained on normalized observations, the test environment must apply the same normalization. The `VecNormalize` statistics are therefore saved to `./models/vec_normalize.pkl` and then loaded into the test environment. Setting `training = False` freezes the loaded statistics, and `norm_reward = False` reports the raw, unnormalized rewards."
   ]
  },
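  {
   "cell_type": "markdown",
   "id": "e1f7a9c3",
   "metadata": {},
   "source": [
    "`VecNormalize.save` pickles the wrapper together with its running statistics, and `VecNormalize.load` restores them around a new environment. The sketch below mimics that round trip with plain `pickle` so the idea is visible without Stable Baselines3. The `stats` dictionary and the `normalize` helper are hypothetical. Because the loaded statistics are only read, never updated, the same observation always maps to the same normalized value. This is what `training = False` provides at test time.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import pickle\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# hypothetical statistics collected during training\n",
    "stats = {'mean': np.array([50.0, -3.0]), 'var': np.array([400.0, 0.01]), 'epsilon': 1e-8}\n",
    "\n",
    "\n",
    "def normalize(obs, stats, clip=10.0):\n",
    "    # apply frozen statistics: nothing is updated, so the mapping stays fixed at test time\n",
    "    z = (obs - stats['mean']) / np.sqrt(stats['var'] + stats['epsilon'])\n",
    "    return np.clip(z, -clip, clip)\n",
    "\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    path = os.path.join(tmp, 'vec_normalize.pkl')\n",
    "    # save the statistics after training ...\n",
    "    with open(path, 'wb') as f:\n",
    "        pickle.dump(stats, f)\n",
    "    # ... and load them for testing\n",
    "    with open(path, 'rb') as f:\n",
    "        loaded = pickle.load(f)\n",
    "\n",
    "obs = np.array([70.0, -2.9])\n",
    "print(normalize(obs, loaded))  # approximately [1, 1]\n",
    "```"
   ]
  },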
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32a675ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the normalization statistics\n",
    "train_env.save(\"./models/vec_normalize.pkl\")\n",
    "train_env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67aa8a66",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_env = gym.make(\"kdFlex-ATRVjr\", render_mode=\"human\", sync_real_time=True)\n",
    "test_env = DummyVecEnv([lambda: test_env])\n",
    "test_env = VecNormalize.load(\"./models/vec_normalize.pkl\", test_env)\n",
    "test_env.training = False\n",
    "test_env.norm_reward = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "439073f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run one episode; a VecEnv step returns (observations, rewards, dones, infos)\n",
    "done = False\n",
    "obs = test_env.reset()\n",
    "while not done:\n",
    "    action, _states = model.predict(obs, deterministic=True)\n",
    "    obs, reward, done, info = test_env.step(action)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a02967bf",
   "metadata": {},
   "source": [
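    "If you have difficulty training the model, a working pretrained model is provided as `./models/demo_model.zip`, with its normalization statistics in `./models/demo_vec_normalize.pkl`. The next cell loads both into the test environment."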
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaa0a792",
   "metadata": {},
   "outputs": [],
   "source": [
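    "test_env.close()\n",
    "test_env = gym.make(\"kdFlex-ATRVjr\", render_mode=\"human\", sync_real_time=True)\n",
    "test_env = DummyVecEnv([lambda: test_env])\n",
    "test_env = VecNormalize.load(\"./models/demo_vec_normalize.pkl\", test_env)\n",
    "test_env.training = False\n",
    "test_env.norm_reward = False\n",
    "model = SAC.load(\"./models/demo_model\", env=test_env)\n",
    "\n",
    "# run one episode with the pretrained model\n",
    "done = False\n",
    "obs = test_env.reset()\n",
    "while not done:\n",
    "    action, _states = model.predict(obs, deterministic=True)\n",
    "    obs, reward, done, info = test_env.step(action)"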
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "333255ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cleanup\n",
    "test_env.close()\n",
    "# hard process exit\n",
    "import os\n",
    "\n",
    "os._exit(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4ca1930",
   "metadata": {},
   "source": [
    "## Summary\n",
    "By following Gymnasium's environment API, you have implemented and solved a more complex driving scenario. You can now take advantage of kdFlex's low computational cost to tackle increasingly difficult problems. Good luck!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
