import numpy as np
from math import pi
from typing import cast
import gymnasium as gym
from typing import Optional

from Karana.Frame import FrameContainer
from Karana.Core import discard, allReady, allDestroyed
from Karana.Dynamics import (
    Multibody,
    HingeType,
    StatePropagator,
    TimedEvent,
    PhysicalBody,
    PhysicalHinge,
    LinearSubhinge,
    PinSubhinge,
)
from Karana.Math import UnitQuaternion, HomTran, SpatialInertia
from Karana.Scene import (
    BoxGeometry,
    CylinderGeometry,
    Color,
    PhysicalMaterialInfo,
    PhysicalMaterial,
    LAYER_GRAPHICS,
)
from Karana.Scene import ProxyScenePart, ProxyScene
from Karana.Models import Gravity, OutputUpdateType, UniformGravity, UpdateProxyScene, SyncRealTime
from Karana.Integrators import IntegratorType


# create a cart and a pole multibody system
def createMultibody(fc: FrameContainer) -> tuple[LinearSubhinge, PinSubhinge, Multibody]:
    """Build the cart pole Multibody.

    The cart is connected to the virtual root by a slider hinge whose unit axis is
    +x. The pole is connected to the cart by a revolute hinge whose unit axis is +y.
    The pin is offset half the cart height along +z from the cart frame, and half
    the pole length along -z from the pole frame.

    Parameters
    ----------
    fc : FrameContainer
        The FrameContainer the Multibody will use.

    Returns
    -------
    tuple[LinearSubhinge, PinSubhinge, Multibody]
        The cart's slider subhinge, the pole's pin subhinge, and the Multibody itself.
    """
    multibody = Multibody("mb", fc)

    # --- cart: 1.0 kg, 0.4 x 0.2 x 0.15 (L, W, H); only the height is needed here ---
    cart_height = 0.15
    cart = PhysicalBody("cart", multibody)
    cart.setSpatialInertia(SpatialInertia(1.0, np.zeros(3), np.eye(3)))
    slider_joint = PhysicalHinge(multibody.virtualRoot(), cart, HingeType.SLIDER)
    cart_slider = cast(LinearSubhinge, slider_joint.subhinge(0))
    cart_slider.setUnitAxis([1, 0, 0])
    # The slider joint coincides with the cart frame origin.
    cart.setBodyToJointTransform(HomTran([0, 0, 0]))

    # --- pole: 0.1 kg, 0.02 x 0.5 (R, L) ---
    pole_length = 0.5
    pole = PhysicalBody("pole", multibody)
    pole_inertia = np.diag([0.0020933, 0.0020933, 0.00002])
    pole.setSpatialInertia(SpatialInertia(0.1, [0, 0, 0], pole_inertia))
    pin_joint = PhysicalHinge(cart, pole, HingeType.REVOLUTE)
    # Joint sits half a pole length along -z from the pole frame ...
    pole.setBodyToJointTransform(HomTran([0.0, 0.0, -(pole_length / 2.0)]))
    # ... and half a cart height along +z from the cart frame.
    pin_joint.onode().setBodyToNodeTransform(HomTran([0.0, 0.0, (cart_height / 2.0)]))
    pole_pin = cast(PinSubhinge, pin_joint.subhinge(0))
    pole_pin.setUnitAxis([0, 1, 0])

    # Sanity-check the assembled system before handing it out.
    multibody.ensureHealthy()
    multibody.resetData()
    assert allReady()
    return cart_slider, pole_pin, multibody


# add a visual rectangle and cylinder
def addGeometries(mb: Multibody, proxy_scene: ProxyScene) -> None:
    """Attach visual geometry for the cart, the pole, and a floor.

    The cart gets a firebrick box and the pole gets a gold cylinder, each attached
    to the matching Multibody body. A white floor box is added to the scene without
    being attached to any body. All parts go on the graphics layer.

    Parameters
    ----------
    mb : Multibody
        The Multibody whose "cart" and "pole" bodies receive the geometries.
    proxy_scene : ProxyScene
        The ProxyScene the scene parts are created in.
    """
    cart_shape = BoxGeometry(0.4, 0.2, 0.15)
    pole_shape = CylinderGeometry(0.02, 0.5)

    # A single PhysicalMaterialInfo is reused; only its color changes between materials.
    info = PhysicalMaterialInfo()

    def _material(color):
        info.color = color
        return PhysicalMaterial(info)

    def _part(part_name, shape, material):
        return ProxyScenePart(
            part_name, scene=proxy_scene, geometry=shape, material=material, layers=LAYER_GRAPHICS
        )

    firebrick = _material(Color.FIREBRICK)
    gold = _material(Color.GOLD)
    white = _material(Color.WHITE)

    # visuals that follow the bodies
    cart_part = _part("cart_body", cart_shape, firebrick)
    cart_part.attachTo(mb.getBody("cart"))
    pole_part = _part("pole_body", pole_shape, gold)
    # Quarter turn about x; presumably aligns the cylinder's default axis with the
    # pole's z axis (confirm against CylinderGeometry's convention).
    pole_part.setUnitQuaternion(UnitQuaternion(pi / 2, [1, 0, 0]))
    pole_part.attachTo(mb.getBody("pole"))

    # static floor; -0.325 == -(0.15 / 2 + 0.5 / 2), i.e. floor top meets the cart
    # bottom, assuming box geometries are centered on their part frames (confirm).
    floor_shape = BoxGeometry(20, 20, 0.5)
    floor_part = _part("ground", floor_shape, white)
    floor_part.setTranslation([0, 0, -0.325])


class CartPoleEnv(gym.Env):
    """Gymnasium environment wrapping a Karana cart pole simulation.

    The cart rides a slider hinge and the pole swings on a pin hinge mounted on the
    cart (built by ``createMultibody``). Each ``step`` applies a constant force to
    the cart selected by a discrete action, advances the simulation by 0.02, and
    returns a reward of 1.0. The episode terminates once ``|cart position| > 1.5``
    or ``|pole angle| > 0.21``.
    """

    metadata = {"render_modes": ["human", "none"]}

    def __init__(
        self,
        render_mode: Optional[str] = None,
        port: int = 0,
        sync_real_time: bool = False,
        **kwargs,
    ):
        """Create an instance of CartPoleEnv.

        Parameters
        ----------
        render_mode : Optional[str]
            The render mode for the ProxyScene. Values not listed in
            ``metadata["render_modes"]`` (including None) fall back to ``"none"``.
        port : int
            The port to use for the ProxyScene graphics (only used in ``"human"`` mode).
        sync_real_time : bool
            If True, then add a SyncRealTime model. If False, then do not add it.
        **kwargs
            Extra keyword arguments. Accepted and ignored.
        """
        self.render_mode = "none"
        if render_mode and render_mode in self.metadata["render_modes"]:
            self.render_mode = render_mode

        # init simulation (init_sim reads self.render_mode, so it must be set first)
        self.init_sim(port=port, sync_real_time=sync_real_time)

        # observation space of [cart pos, cart vel, pin pos, pin vel]
        low = np.array([-1.5, -2, -pi / 3, -6.0], dtype=np.float32)
        high = np.array([1.5, 2, pi / 3, 6.0], dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32, shape=(4,))

        # action space of negative or positive force on cart
        self.action_space = gym.spaces.Discrete(2)
        self._action_to_force = {0: -10.0, 1: 10.0}

        # Episode state; populated by reset(), so step() must not run before it.
        # The RNG is not initialized here: gym.Env.np_random creates it lazily
        # and super().reset(seed=...) reseeds it.
        self._cart_pos = None
        self._cart_vel = None
        self._pin_pos = None
        self._pin_vel = None
        self._force = None
        self._terminate = None
        # Set by close() so that repeated close() calls are harmless.
        self._closed = False

    def init_sim(self, port: int = 0, sync_real_time: bool = False):
        """Initialize the simulation.

        Builds the Multibody and its scene (graphics-backed in ``"human"`` mode,
        otherwise a bare ProxyScene). It then builds an RK4 StatePropagator with
        gravity, optional real-time sync, scene updates, the cart force callback,
        and the termination check.

        Parameters
        ----------
        port : int
            The port to use for the ProxyScene.
        sync_real_time : bool
            If True, then add a SyncRealTime model. If False, then do not add it.
        """
        self._fc = FrameContainer("root")
        self._slider, self._pin, self._mb = createMultibody(self._fc)

        # add visuals
        if self.render_mode == "human":
            self._cleanup_graphics, web_scene = self._mb.setupGraphics(port=port, axes=0)
            self._proxy_scene = self._mb.getScene()
            web_scene.defaultCamera().pointCameraAt([0, 3, 1], [0, 0, 0], [0, 0, 1])
            addGeometries(self._mb, self._proxy_scene)
        else:
            # Headless: a scene is still needed by the UpdateProxyScene model below.
            self._proxy_scene = ProxyScene(f"{self._mb.name()}_proxyscene", self._mb.virtualRoot())

        # add models
        self._sp = StatePropagator(self._mb, integrator_type=IntegratorType.RK4)
        # gravity model: uniform 9.81 along -z
        ug = Gravity("grav_model", self._sp, UniformGravity("uniform_gravity"), self._mb)
        ug.getGravityInterface().setGravity(np.array([0, 0, -9.81]), 0.0, OutputUpdateType.PRE_HOP)
        # Drop the local Python reference; the model is presumably kept alive by the
        # StatePropagator it was registered with (confirm).
        del ug

        # for visualization
        if sync_real_time:
            SyncRealTime("sync_real_time", self._sp, 1.0)

        # update proxy scene model
        UpdateProxyScene("update_proxy_scene", self._sp, self._proxy_scene)

        # apply slider force to move the cart (t and x are unused)
        def apply_force(t, x):
            self._slider.setT(self._force)

        # check if pole or cart fall past a threshold
        # max cart position is 1.5, max pole angle is ~12 deg (0.21 rad)
        def check_terminate(t, x):
            cart_pos = self._slider.getQ()
            pin_pos = self._pin.getQ()
            if abs(cart_pos) > 1.5 or abs(pin_pos) > 0.21:
                self._terminate = True
                return True
            return False

        self._sp.fns.pre_deriv_fns["apply_force"] = apply_force
        # NOTE(review): returning True presumably stops advanceBy early — confirm.
        self._sp.fns.terminate_advance_to_fns["check_terminate"] = check_terminate

    # what the model sees
    def _get_obs(self):
        """Return ``[cart pos, cart vel, pin pos, pin vel]`` as a float32 array."""
        return np.array(
            [
                float(self._cart_pos),
                float(self._cart_vel),
                float(self._pin_pos),
                float(self._pin_vel),
            ],
            dtype=np.float32,
        )

    # for user debugging
    def _get_info(self):
        """Return the current state values keyed by name, as stored (no conversion)."""
        return {
            "cart_pos": self._cart_pos,
            "cart_vel": self._cart_vel,
            "pin_pos": self._pin_pos,
            "pin_vel": self._pin_vel,
        }

    def close(self):
        """Close down the environment and remove the simulation.

        Safe to call more than once; every call after the first does nothing.
        """
        if getattr(self, "_closed", False):
            return
        # Drop Python handles first so the discards below can tear the objects down
        # (presumably required for allDestroyed() to hold).
        del self._slider, self._pin, self._proxy_scene
        discard(self._sp)
        if self.render_mode == "human":
            self._cleanup_graphics()
        discard(self._mb)
        discard(self._fc)
        self._closed = True
        assert allDestroyed()

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset the cart pole system.

        The cart position/velocity and pin angle/rate are each drawn uniformly
        from [-0.1, 0.1], the applied force is zeroed, and the time is set to 0.

        Parameters
        ----------
        seed : Optional[int]
            The seed used for random values.
        options : Optional[dict]
            Extra options used for the reset. Currently unused.

        Returns
        -------
        tuple[np.ndarray, dict]
            The initial observation and the info dictionary.
        """
        # gym.Env.reset reseeds self.np_random when a seed is given; otherwise
        # gym.Env creates the generator lazily on first access, so np_random is
        # always a usable Generator here.
        super().reset(seed=seed)

        # init values
        self._cart_pos = self.np_random.uniform(-0.1, 0.1)
        self._cart_vel = self.np_random.uniform(-0.1, 0.1)
        self._pin_pos = self.np_random.uniform(-0.1, 0.1)
        self._pin_vel = self.np_random.uniform(-0.1, 0.1)
        self._force = 0.0
        self._terminate = False

        # set our state
        self._mb.resetData()
        self._slider.setQ(self._cart_pos)
        self._slider.setU(self._cart_vel)
        self._pin.setQ(self._pin_pos)
        self._pin.setU(self._pin_vel)
        t_init = np.timedelta64(0, "ns")
        self._sp.setTime(t_init)
        self._sp.setState(self._sp.assembleState())

        # setup timed event: a no-op event every 1e7 ns (10 ms) that fixes the hop size.
        # NOTE(review): a new "hop_size" event is registered on every reset; confirm
        # the StatePropagator does not accumulate duplicates across episodes.
        h = np.timedelta64(int(1e7), "ns")
        t = TimedEvent("hop_size", h, lambda _: None, False)
        t.period = h
        self._sp.registerTimedEvent(t)
        del t
        self._proxy_scene.update()
        return self._get_obs(), self._get_info()

    def step(self, action):
        """Step the simulation.

        Parameters
        ----------
        action :
            The action to take: 0 applies -10.0 to the cart, 1 applies +10.0.

        Returns
        -------
        tuple[np.ndarray, float, bool, bool, dict]
            Observation, reward (always 1.0), terminated, truncated (always
            False), and info.
        """
        self._force = self._action_to_force[action]
        self._sp.advanceBy(0.02)
        self._cart_pos = self._slider.getQ()
        self._cart_vel = self._slider.getU()
        self._pin_pos = self._pin.getQ()
        self._pin_vel = self._pin.getU()
        # observation, reward, terminated, truncated, info
        return self._get_obs(), 1.0, self._terminate, False, self._get_info()
