from typing import Optional, cast
from Karana.Scene.Scene_types import ScenePartSpecDS
import numpy as np
from numpy.typing import NDArray
import gymnasium as gym
from pathlib import Path
from math import pi, sqrt

from Karana.Frame import FrameContainer
from Karana.Core import discard, allDestroyed
from Karana.Dynamics.SOADyn_types import (
    BodyWithContextDS,
    SubGraphDS,
    PinSubhingeDS,
    Linear3SubhingeDS,
    SphericalSubhingeDS,
    HingeDS,
)
from Karana.Dynamics import (
    Multibody,
    HingeType,
    PinSubhinge,
    StatePropagator,
    TimedEvent,
    PhysicalBody,
    PhysicalHinge,
)
from Karana.Math import UnitQuaternion, HomTran
from Karana.Scene import (
    BoxGeometry,
    CylinderGeometry,
    Color,
    PhysicalMaterialInfo,
    PhysicalMaterial,
    LAYER_GRAPHICS,
    LAYER_ALL,
    LAYER_COLLISION,
    Texture,
)
from Karana.Scene import ProxyScenePart, ProxyScene
from Karana.Scene import CoalScene
from Karana.Collision import FrameCollider, HuntCrossley
from Karana.Models import (
    Gravity,
    OutputUpdateType,
    UniformGravity,
    UpdateProxyScene,
    SyncRealTime,
    PenaltyContact,
)
from Karana.Integrators import IntegratorType
from Karana.KUtils.BasicPrefab import BasicPrefabDS


# helper to wrap angles
def wrapToPi(x: float) -> float:
    """Wrap an angle so that it lies between -pi and pi.

    The input is shifted by pi, reduced modulo one full turn, and shifted
    back. An input of exactly pi is mapped to -pi.

    Parameters
    ----------
    x : float
        The angle to wrap, in radians.

    Returns
    -------
    float
        The equivalent angle wrapped between -pi and pi.
    """
    full_turn = 2 * np.pi
    shifted = x + np.pi
    return shifted % full_turn - np.pi


def createMbody(fc: FrameContainer) -> Multibody:
    """Build the ATRV Jr. Multibody from its URDF description.

    The URDF is read from ``resources/atrvjr/atrvjr.urdf`` under the parent
    of the current working directory. Before the Multibody is instantiated,
    the loaded description is adjusted: the first base body gets a full
    6-DOF hinge, its first four children get revolute (pin) hinges, and
    every scene part whose layers equal ``LAYER_COLLISION`` is dropped.

    Parameters
    ----------
    fc : FrameContainer
        The FrameContainer the Multibody will use.

    Returns
    -------
    Multibody
        The newly created Multibody.
    """
    urdf_path = Path().absolute().parent / "resources" / "atrvjr" / "atrvjr.urdf"
    subtree = BasicPrefabDS.fromFile(urdf_path).params.subtree
    multibody_info = SubGraphDS(name=subtree.name, base_bodies=subtree.base_bodies)
    rover = multibody_info.base_bodies[0]

    # Replace the hinge between the root frame and the rover with a full 6-DOF hinge.
    rover.body.hinge = HingeDS(
        hinge_type=HingeType.FULL6DOF,
        subhinges=[
            Linear3SubhingeDS(prescribed=False),
            SphericalSubhingeDS(prescribed=False),
        ],
    )

    # Unlock the wheel hinges that the URDF declares as locked. Children 0 and 1
    # spin about +z, children 2 and 3 about -z.
    wheel_axes = ((0, 1.0), (1, 1.0), (2, -1.0), (3, -1.0))
    for child_idx, z_axis in wheel_axes:
        rover.children[child_idx].body.hinge = HingeDS(
            hinge_type=HingeType.REVOLUTE,
            subhinges=[PinSubhingeDS(prescribed=False, unit_axis=np.array([0.0, 0.0, z_axis]))],
        )

    # Remove all collision meshes. We want to add these manually.
    def _stripCollisionParts(node: BodyWithContextDS):
        parts = node.body.scene_parts
        # Walk backwards so that popping never shifts an index still to be visited.
        for k in range(len(parts) - 1, -1, -1):
            if parts[k].layers == LAYER_COLLISION:
                parts.pop(k)
        for child in node.children:
            _stripCollisionParts(child)

    for base in multibody_info.base_bodies:
        _stripCollisionParts(base)

    multibody = multibody_info.toMultibody(fc)
    multibody.ensureHealthy()
    multibody.resetData()
    assert multibody.isReady()
    return multibody


def createEnvironment(
    size: int, mb: Multibody, proxy_scene: ProxyScene
) -> tuple[PhysicalBody, PhysicalBody, PhysicalBody, PhysicalBody]:
    """Populate the scene for the ATRV Jr. sim: wheels, ground and walls.

    A gray cylinder is attached to each of the four wheel bodies, a textured
    ground box is added, and four khaki wall boxes are placed around it.

    Parameters
    ----------
    size : int
        The size of the walls and ground.
    mb : Multibody
        The Multibody to create the environment for.
    proxy_scene : ProxyScene
        The ProxyScene to put the geometry in.

    Returns
    -------
    tuple[PhysicalBody, PhysicalBody, PhysicalBody, PhysicalBody]
        The 4 wheel bodies of the ATRV Jr. vehicle, in the order
        left-front, left-rear, right-front, right-rear.
    """
    thickness = 0.25
    height = 2.5

    # A single material-info object is reused; each PhysicalMaterial is built
    # from its state at the time of construction.
    mat_info = PhysicalMaterialInfo()
    mat_info.color = Color.GRAY
    gray = PhysicalMaterial(mat_info)

    # add wheels to our atrv model
    wheel_geom = CylinderGeometry(0.125, 0.075)
    wheel_links = (
        "left_front_wheel_link",
        "left_rear_wheel_link",
        "right_front_wheel_link",
        "right_rear_wheel_link",
    )
    wheels = []
    for link_name in wheel_links:
        wheel_body = cast(PhysicalBody, mb.getBody(link_name))
        wheel_node = ProxyScenePart(
            name=f"{link_name}_node", scene=proxy_scene, geometry=wheel_geom, material=gray
        )
        wheel_node.attachTo(wheel_body)
        wheel_node.setTranslation([0, 0, 0.03])
        wheel_node.setUnitQuaternion(UnitQuaternion(pi / 2, [1.0, 0.0, 0.0]))
        wheels.append(wheel_body)

    # add ground
    ground_geom = BoxGeometry(size, size, thickness)
    mat_info.color = Color.WHITE
    texture_file = Path().absolute().parent / "resources" / "atrvjr" / "grid.png"
    mat_info.color_map = Texture.lookupOrCreateTexture(str(texture_file))
    grid = PhysicalMaterial(mat_info)
    ground = ProxyScenePart(
        "ground", scene=proxy_scene, geometry=ground_geom, material=grid, layers=LAYER_ALL
    )
    ground.setTranslation([0, 0, -0.1])

    # add walls
    mat_info.color = Color.KHAKI
    mat_info.color_map = None
    khaki = PhysicalMaterial(mat_info)
    wall_geom = BoxGeometry(size, thickness, height)

    edge_offset = size / 2.0 + thickness / 2
    wall_z = height / 2 - thickness
    inv_sqrt2 = 1 / sqrt(2)
    # (name, translation, quaternion components or None when no rotation is applied)
    wall_specs = (
        ("north_wall", [edge_offset, 0, wall_z], [inv_sqrt2, inv_sqrt2, 0, 0]),
        ("south_wall", [-edge_offset, 0, wall_z], [-inv_sqrt2, inv_sqrt2, 0, 0]),
        ("west_wall", [0, edge_offset, wall_z], None),
        ("east_wall", [0, -edge_offset, wall_z], None),
    )
    for wall_name, translation, quat in wall_specs:
        wall = ProxyScenePart(
            wall_name, scene=proxy_scene, geometry=wall_geom, material=khaki, layers=LAYER_ALL
        )
        wall.setTranslation(translation)
        if quat is not None:
            wall.setUnitQuaternion(UnitQuaternion(quat))

    left_front_wheel, left_rear_wheel, right_front_wheel, right_rear_wheel = wheels
    return left_front_wheel, left_rear_wheel, right_front_wheel, right_rear_wheel


class ATRVEnv(gym.Env):
    """Environment for the ATRV Jr. system.

    The agent drives a four-wheeled rover towards a randomly placed target
    inside a walled arena.

    Observation (10 float32 values):
        ``[txpos, typos, x, y, vx, vy, roll, yaw_vel, sin(angle_diff), cos(angle_diff)]``

    Action (2 float32 values in ``[-1, 1]``):
        ``[left, right]`` wheel torque commands, scaled by the torque
        magnitude before being applied to the wheel hinges.
    """

    metadata = {"render_modes": ["human", "none"]}

    def __init__(
        self,
        render_mode: Optional[str] = None,
        port: int = 0,
        sync_real_time: bool = False,
        **_,
    ):
        """Create an instance of ATRVEnv.

        Parameters
        ----------
        render_mode : Optional[str]
            The render mode for the ProxyScene. Values not listed in
            ``metadata["render_modes"]`` fall back to ``"none"``.
        port : int
            The port to use for ProxyScene.
        sync_real_time : bool
            If True, then add a SyncRealTime model. If False, then do not add it.
        **_
            Extra keyword arguments.
        """
        self.render_mode = "none"
        if render_mode and render_mode in self.metadata["render_modes"]:
            self.render_mode = render_mode
        self.size = 30
        # Targets and start positions are sampled within +/- _bound on x and y.
        self._bound = self.size / 2.0 - 3.0
        self._torque_magnitude = 20.0

        # initialize the simulation
        self.init_sim(port, sync_real_time)

        # [txpos, typos, x, y, vx, vy, roll, yaw_vel, sin(angle_diff), cos(angle_diff)]
        low = np.array(
            [
                -self._bound,
                -self._bound,
                -self.size,
                -self.size,
                -6.0,
                -6.0,
                -0.25,
                -6.0,
                -1.0,
                -1.0,
            ],
            dtype=np.float32,
        )
        high = np.array(
            [self._bound, self._bound, self.size, self.size, 6.0, 6.0, 0.25, 6.0, 1.0, 1.0],
            dtype=np.float32,
        )
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        # define action space
        # torque on [left, right]
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

        # initialize class variables
        self._target = None
        # Set by the termination callback or by step(); reset() clears it.
        self._terminated: bool = False
        self._left: float
        self._right: float
        self._target_pos: NDArray[np.float32]
        self._agent_pos: NDArray[np.float32 | np.float64]
        self._agent_vel: NDArray[np.float32 | np.float64]
        self._agent_roll: NDArray[np.float32]
        self._agent_yaw_vel: NDArray[np.float32]
        # One-element array (not a bare float): _get_obs concatenates its sin/cos.
        self._angle_diff: NDArray[np.float32 | np.float64]
        self.np_random = None

    def init_sim(self, port: int = 0, sync_real_time: bool = False):
        """Initialize the simulation.

        Builds the frame container and multibody, the proxy and collision
        scenes, the state propagator with its gravity, scene-update and
        contact models, and registers the torque and termination callbacks.

        Parameters
        ----------
        port : int
            The port to use for the ProxyScene.
        sync_real_time : bool
            If True, then add a SyncRealTime model. If False, then do not add it.
        """
        # Simulation resources are (about to be) live; close() may release them once.
        self._closed = False

        # setup bodies and scene
        self._fc = FrameContainer("root")
        self._mb = createMbody(self._fc)

        # attach our multibody's to the root frame
        self._f2f = cast(PhysicalBody, self._mb.virtualRoot()).frameToFrame(
            cast(PhysicalBody, self._mb.getBody("center_link"))
        )
        self._root_hinge = cast(PhysicalHinge, self._mb.getBody("center_link").parentHinge())

        # setup graphics if enabled
        if self.render_mode == "human":
            self._cleanup_graphics, viz_scene = self._mb.setupGraphics(axes=0.5, port=port)
            viz_scene.defaultCamera().pointCameraAt(
                [-0.8 * self.size, 0.0, self.size], [0, 0, 0], [0, 0, 1]
            )
            self._proxy_scene = self._mb.getScene()
            if self._proxy_scene is None:
                raise ValueError("Scene not defined.")
        else:
            self._proxy_scene = ProxyScene(
                f"{self._mb.name()}_proxyscene", cast(PhysicalBody, self._mb.virtualRoot())
            )

        # setup wheels, ground, and walls
        left_front_wheel, left_rear_wheel, right_front_wheel, right_rear_wheel = createEnvironment(
            self.size, self._mb, self._proxy_scene
        )

        # setup collision client
        col_scene = CoalScene("collision_scene")
        self._proxy_scene.registerClientScene(
            col_scene, cast(PhysicalBody, self._mb.virtualRoot()), layers=LAYER_COLLISION
        )

        # setup state propagator and models
        self._sp = StatePropagator(self._mb, integrator_type=IntegratorType.RK4)

        # gravity model
        ug = Gravity("grav_model", self._sp, UniformGravity("uniform_gravity"), self._mb)
        ug.getGravityInterface().setGravity(np.array([0, 0, -9.81]), 0.0, OutputUpdateType.PRE_HOP)
        del ug

        # for visualization
        if sync_real_time:
            SyncRealTime("sync_real_time", self._sp, 1.0)

        # update proxy scene model
        UpdateProxyScene("update_proxy_scene", self._sp, self._proxy_scene)

        # collision model
        hc = HuntCrossley("hunt_crossley_contact")
        hc.params.kp = 100000
        hc.params.kc = 20000
        hc.params.mu = 0.3
        hc.params.n = 1.5
        hc.params.linear_region_tol = 1e-3
        mdl = PenaltyContact(
            "penalty_contact", self._sp, self._mb, [FrameCollider(self._proxy_scene, col_scene)], hc
        )
        del mdl, hc

        def check_termination(_, x):
            # check if the car is flipping over
            flip = abs(self._f2f.relTransform().getUnitQuaternion().toEulerAngles().alpha()) > 0.2
            # if more than our 4 wheels are in contact, we have a collision
            if len(col_scene.cachedCollisions()) > 4 or flip:
                self._terminated = True
                return True
            return False

        def pre_deriv_fn(t, _):
            # Both wheels on a side receive the same commanded torque.
            cast(PinSubhinge, left_front_wheel.parentHinge().subhinge(0)).setT(self._left)
            cast(PinSubhinge, left_rear_wheel.parentHinge().subhinge(0)).setT(self._left)
            cast(PinSubhinge, right_front_wheel.parentHinge().subhinge(0)).setT(self._right)
            cast(PinSubhinge, right_rear_wheel.parentHinge().subhinge(0)).setT(self._right)

        self._sp.fns.pre_deriv_fns["apply_torque"] = pre_deriv_fn
        self._sp.fns.terminate_advance_to_fns["check_termination"] = check_termination

    def _get_obs(self):
        """Assemble the observation vector.

        Returns
        -------
        NDArray[np.float32]
            ``[txpos, typos, x, y, vx, vy, roll, yaw_vel, sin(angle_diff), cos(angle_diff)]``.
        """
        # Every entry must be at least 1-D for np.concatenate; _angle_diff is a
        # one-element array, so its sin/cos are too.
        return np.concatenate(
            (
                self._target_pos,
                self._agent_pos,
                self._agent_vel,
                self._agent_roll,
                self._agent_yaw_vel,
                np.sin(self._angle_diff),
                np.cos(self._angle_diff),
            )
        ).astype(np.float32)

    # use non-normalized values for info
    def _get_info(self):
        """Return the current state quantities as a dictionary.

        Returns
        -------
        dict
            The stored target position, agent position, velocity, roll,
            yaw velocity, and angle difference.
        """
        return {
            "target_pos": self._target_pos,
            "agent_pos": self._agent_pos,
            "agent_vel": self._agent_vel,
            "agent_roll": self._agent_roll,
            "agent_yaw_vel": self._agent_yaw_vel,
            "angle_diff": self._angle_diff,
        }

    def close(self):
        """Close down the environment and remove the simulation.

        Calling this more than once is safe: only the first call releases
        the simulation objects, later calls return immediately.
        """
        if getattr(self, "_closed", False):
            return
        # Mark first: the attributes below are deleted, so a retry could not succeed anyway.
        self._closed = True

        del self._f2f, self._root_hinge, self._target
        discard(self._sp)
        if self.render_mode == "human":
            del self._proxy_scene
            self._cleanup_graphics()
        else:
            self._mb.setScene(None)
            discard(self._proxy_scene)
        discard(self._mb)
        discard(self._fc)
        assert allDestroyed()

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset the ATRV Jr. system.

        Samples a new target position, agent position and agent yaw, places
        the rover accordingly, and rewinds the simulation time to zero.

        Parameters
        ----------
        seed : Optional[int]
            The seed used for random values.
        options : Optional[dict]
            Extra options used for the reset. Currently unused.

        Returns
        -------
        tuple
            ``(observation, info)`` for the freshly reset state.
        """
        super().reset(seed=seed)
        # Fallback in case the base class did not provide a generator.
        if self.np_random is None:
            self.np_random, _ = gym.utils.seeding.np_random(seed)

        # initialize values
        self._terminated = False
        self._left = 0.0
        self._right = 0.0
        self._target_pos = self.np_random.uniform(-self._bound, self._bound, size=(2,)).astype(
            np.float32
        )
        self._agent_pos = self.np_random.uniform(-self._bound, self._bound, size=(2,)).astype(
            np.float32
        )
        self._agent_vel = np.zeros(2, dtype=np.float32)
        self._agent_roll = np.zeros(1, dtype=np.float32)
        self._agent_yaw_vel = np.zeros(1, dtype=np.float32)
        agent_yaw = self.np_random.uniform(-np.pi, np.pi, size=(1,)).astype(np.float32)

        # calculate previous distance and angle difference
        difference = self._target_pos - self._agent_pos
        ideal_angle = np.arctan2(difference[1], difference[0])
        self._distance = np.linalg.norm(difference)
        # agent_yaw has shape (1,), so _angle_diff is a one-element array.
        self._angle_diff = wrapToPi(ideal_angle - agent_yaw)

        # reset the rover position and velocity
        self._mb.resetData()
        orientation = UnitQuaternion(agent_yaw[0], [0.0, 0.0, 1.0])
        self._root_hinge.fitQ(
            HomTran(q=orientation, vec=[self._agent_pos[0], self._agent_pos[1], 0])
        )

        # add the target (created once, then only moved on later resets)
        if not self._target:
            mat_info = PhysicalMaterialInfo()
            mat_info.color = Color.BLUE
            blue = PhysicalMaterial(mat_info)

            # create the target
            target_geom = CylinderGeometry(0.25, 0.05)
            self._target = ProxyScenePart(
                "target_part",
                scene=self._proxy_scene,
                geometry=target_geom,
                material=blue,
                layers=LAYER_GRAPHICS,
            )
        self._target.setTranslation([self._target_pos[0], self._target_pos[1], 0.25])
        self._target.setUnitQuaternion(UnitQuaternion([1 / sqrt(2), 0, 0, 1 / sqrt(2)]))

        # simulation
        t_init = np.timedelta64(0, "ns")
        self._sp.setTime(t_init)
        self._sp.setState(self._sp.assembleState())

        # setup timed event (every 0.01s)
        # NOTE(review): a "hop_size" event is registered on every reset; confirm that
        # earlier registrations are cleared or replaced rather than accumulated.
        h = np.timedelta64(int(1e7), "ns")
        t = TimedEvent("hop_size", h, lambda _: None, False)
        t.period = h
        self._sp.registerTimedEvent(t)
        del t
        self._proxy_scene.update()
        return self._get_obs(), self._get_info()

    def step(self, action: np.ndarray):
        """Step the simulation.

        Parameters
        ----------
        action :
            The action to take: two values ``[left, right]`` that are scaled
            by the torque magnitude. Any array-like of length 2 is accepted.

        Returns
        -------
        tuple
            ``(observation, reward, terminated, truncated, info)``. ``reward``
            is a Python float and ``truncated`` is always False.
        """
        # np.asarray leaves ndarray inputs untouched and lets lists/tuples work too.
        torque = np.asarray(action) * self._torque_magnitude
        self._left = torque[0]
        self._right = torque[1]
        self._sp.advanceBy(0.05)
        T = self._f2f.relTransform()
        self._agent_pos = T.getTranslation().m[:2]
        q = T.getUnitQuaternion()
        ea = q.toEulerAngles()
        agent_yaw = np.array([ea.gamma()])
        self._agent_roll = np.array([ea.alpha()])
        sp_vel = self._f2f.relSpVel().toVector6()
        self._agent_vel = sp_vel[3:5]
        self._agent_yaw_vel = np.array([sp_vel[2]])

        # check termination
        difference = self._target_pos - self._agent_pos
        distance = np.linalg.norm(difference)
        ideal_angle = np.arctan2(difference[1], difference[0])
        angle_diff = wrapToPi(ideal_angle - agent_yaw)
        reward = 0.0
        if bool(distance < 1.0):
            # reached the target
            self._terminated = True
            reward += 4.0
        elif self._terminated:
            # the termination callback fired during the advance (collision or flip)
            reward -= 4.0
        else:
            # reward calculations
            progress = self._distance - distance
            angle_progress = np.abs(self._angle_diff) - np.abs(angle_diff)
            reward += progress
            reward += 4 * angle_progress
            reward -= 0.03  # step penalty
            self._distance = distance
            self._angle_diff = angle_diff

        # The angle terms are one-element arrays, so reward may be a shape-(1,)
        # array here. Extract the element explicitly: float() on an array with
        # ndim > 0 is deprecated in NumPy.
        reward_value = float(np.asarray(reward).item())

        # observation, reward, terminated, truncated, data
        return self._get_obs(), reward_value, self._terminated, False, self._get_info()
