Program Listing for File pybind11Kquantities.h

Program Listing for File pybind11Kquantities.h

Return to documentation for file (include/Karana/Math/pybind11Kquantities.h)

/*
 * Copyright (c) 2024-2026 Karana Dynamics Pty Ltd. All rights reserved.
 *
 * NOTICE TO USER:
 *
 * This source code and/or documentation (the "Licensed Materials") is
 * the confidential and proprietary information of Karana Dynamics Inc.
 * Use of these Licensed Materials is governed by the terms and conditions
 * of a separate software license agreement between Karana Dynamics and the
 * Licensee ("License Agreement"). Unless expressly permitted under that
 * agreement, any reproduction, modification, distribution, or disclosure
 * of the Licensed Materials, in whole or in part, to any third party
 * without the prior written consent of Karana Dynamics is strictly prohibited.
 *
 * THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND.
 * KARANA DYNAMICS DISCLAIMS ALL WARRANTIES, EXPRESS OR IMPLIED, INCLUDING
 * BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-INFRINGEMENT, AND
 * FITNESS FOR A PARTICULAR PURPOSE.
 *
 * IN NO EVENT SHALL KARANA DYNAMICS BE LIABLE FOR ANY DAMAGES WHATSOEVER,
 * INCLUDING BUT NOT LIMITED TO LOSS OF PROFITS, DATA, OR USE, EVEN IF
 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGES, WHETHER IN CONTRACT, TORT,
 * OR OTHERWISE ARISING OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS.
 *
 * U.S. Government End Users: The Licensed Materials are a "commercial item"
 * as defined at 48 C.F.R. 2.101, and are provided to the U.S. Government
 * only as a commercial end item under the terms of this license.
 *
 * Any use of the Licensed Materials in individual or commercial software must
 * include, in the user documentation and internal source code comments,
 * this Notice, Disclaimer, and U.S. Government Use Provision.
 */

#pragma once

#include "Karana/Math/Defs.h"
#include "Karana/Math/SpatialVector.h"
#include <format>
#include <pybind11/eigen.h>
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace km = Karana::Math;

namespace Karana::Math {
    /**
     * @struct FloatQuantityHelper
     * @brief Common base for the scalar quantity wrappers (Length, Angle, Mass, ...) brought
     *        into pybind11.
     *
     * Each derived struct is a distinct C++ type so that it can have its own pybind11
     * type_caster, while the payload itself is a plain double.
     */
    struct FloatQuantityHelper {
        /**
         * @brief Default constructor. Used by pybind11 caster.
         *
         * The stored value starts at 0.0 (see the member initializer) instead of being
         * left uninitialized.
         */
        FloatQuantityHelper() = default;

        /**
         * @brief Constructor from a double value.
         *
         * Not explicit, so a double converts implicitly to this type.
         *
         * @param v The value to store.
         */
        FloatQuantityHelper(double v)
            : value(v) {}

        /// The value expressed as a double.
        double value = 0.0;

        /**
         * @brief Conversion operator to double.
         *
         * @return The internal value.
         */
        operator double() const { return value; }
    };

    /**
     * @struct Length
     * @brief Simple wrapper struct to bring Length into pybind11.
     */
    struct Length : FloatQuantityHelper {};

    /**
     * @struct Angle
     * @brief Simple wrapper struct to bring Angle into pybind11.
     */
    struct Angle : FloatQuantityHelper {};

    /**
     * @struct Mass
     * @brief Simple wrapper struct to bring Mass into pybind11.
     */
    struct Mass : FloatQuantityHelper {};

    /**
     * @struct GravitationalParameter
     * @brief Simple wrapper struct to bring GravitationalParameter into pybind11.
     */
    struct GravitationalParameter : FloatQuantityHelper {};

    /**
     * @struct Vec3QuantityHelper
     * @brief Common base for the Vec3-valued quantity wrappers (Length3, Force3, ...) brought
     *        into pybind11.
     */
    struct Vec3QuantityHelper {
        /**
         * @brief Default constructor. Used by pybind11 caster.
         *
         * The stored value starts as the zero vector (see the member initializer).
         * NOTE(review): this relies on Vec3 (declared in Karana/Math/Defs.h) being an
         * Eigen fixed-size type that provides Zero(); Eigen's own default constructor
         * leaves the coefficients uninitialized.
         */
        Vec3QuantityHelper() = default;

        /**
         * @brief Constructor from a Vec3 value.
         *
         * Not explicit, so a Vec3 converts implicitly to this type.
         *
         * @param v The value to store.
         */
        Vec3QuantityHelper(Vec3 v)
            : value(v) {}

        /// The value expressed as a Vec3.
        Vec3 value = Vec3::Zero();

        /**
         * @brief Conversion operator to Vec3.
         *
         * @return A copy of the internal value.
         */
        operator Vec3() const { return value; }
    };

    /**
     * @struct Length3
     * @brief Simple wrapper struct to bring Length3 into pybind11.
     */
    struct Length3 : Vec3QuantityHelper {};

    /**
     * @struct Angle3
     * @brief Simple wrapper struct to bring Angle3 into pybind11.
     */
    struct Angle3 : Vec3QuantityHelper {};

    /**
     * @struct Velocity3
     * @brief Simple wrapper struct to bring Velocity3 into pybind11.
     */
    struct Velocity3 : Vec3QuantityHelper {};

    /**
     * @struct AngularVelocity3
     * @brief Simple wrapper struct to bring AngularVelocity3 into pybind11.
     */
    struct AngularVelocity3 : Vec3QuantityHelper {};

    /**
     * @struct Acceleration3
     * @brief Simple wrapper struct to bring Acceleration3 into pybind11.
     */
    struct Acceleration3 : Vec3QuantityHelper {};

    /**
     * @struct AngularAcceleration3
     * @brief Simple wrapper struct to bring AngularAcceleration3 into pybind11.
     */
    struct AngularAcceleration3 : Vec3QuantityHelper {};

    /**
     * @struct Force3
     * @brief Simple wrapper struct to bring Force3 into pybind11.
     */
    struct Force3 : Vec3QuantityHelper {};

    /**
     * @struct Torque3
     * @brief Simple wrapper struct to bring Torque3 into pybind11.
     */
    struct Torque3 : Vec3QuantityHelper {};

    /**
     * @struct Momentum3
     * @brief Simple wrapper struct to bring Momentum3 into pybind11.
     */
    struct Momentum3 : Vec3QuantityHelper {};

    /**
     * @struct AngularMomentum3
     * @brief Simple wrapper struct to bring AngularMomentum3 into pybind11.
     */
    struct AngularMomentum3 : Vec3QuantityHelper {};

    /**
     * @struct Mat33QuantityHelper
     * @brief Common base for the Mat33-valued quantity wrappers (Inertia, ...) brought into
     *        pybind11.
     */
    struct Mat33QuantityHelper {
        /**
         * @brief Default constructor. Used by pybind11 caster.
         *
         * The stored value starts as the zero matrix (see the member initializer).
         * NOTE(review): this relies on Mat33 (declared in Karana/Math/Defs.h) being an
         * Eigen fixed-size type that provides Zero(); Eigen's own default constructor
         * leaves the coefficients uninitialized.
         */
        Mat33QuantityHelper() = default;

        /**
         * @brief Constructor from a Mat33 value.
         *
         * Not explicit, so a Mat33 converts implicitly to this type.
         *
         * @param v The value to store.
         */
        Mat33QuantityHelper(Mat33 v)
            : value(v) {}

        /// The value expressed as a Mat33.
        Mat33 value = Mat33::Zero();

        /**
         * @brief Conversion operator to Mat33.
         *
         * @return A copy of the internal value.
         */
        operator Mat33() const { return value; }
    };

    /**
     * @struct Inertia
     * @brief Simple wrapper struct to bring Inertia (a Mat33 quantity) into pybind11.
     */
    struct Inertia : Mat33QuantityHelper {};

} // namespace Karana::Math

namespace pybind11::detail {

#define FLOAT_QUANTITY(cpp_type, py_type, quantity_name)                                           \
    template <> struct type_caster<cpp_type> {                                                     \
      public:                                                                                      \
        PYBIND11_TYPE_CASTER(cpp_type,                                                             \
                             pybind11::detail::io_name(                                            \
                                 "Union[typing.SupportsFloat | typing.SupportsIndex | " #py_type   \
                                 "]",                                                              \
                                 #py_type));                                                       \
                                                                                                   \
        /* Conversion from Python -> C++ */                                                        \
        bool load(py::handle src, bool convert) {                                                  \
            if (!src)                                                                              \
                return false;                                                                      \
                                                                                                   \
            /* Imports for conversion */                                                           \
            py::object py_quantity = py::module_::import("pint").attr("Quantity");                 \
            py::object py_length =                                                                 \
                py::module_::import("Karana.Math.Kquantities").attr(#quantity_name);               \
                                                                                                   \
            py::object obj = py::reinterpret_borrow<py::object>(src);                              \
                                                                                                   \
            if (py::isinstance(obj, py_quantity)) {                                                \
                if (py::cast<bool>(obj.attr("check")(py_length))) {                                \
                    /* It's a quantity of the correct type, convert to double */                   \
                    value.value = py::cast<double>(obj.attr("to_base_units")().attr("m"));         \
                } else {                                                                           \
                    /* This is the wrong type, do not convert */                                   \
                    return false;                                                                  \
                }                                                                                  \
            } else {                                                                               \
                /* Assume it's a regular float */                                                  \
                detail::type_caster<double> double_caster;                                         \
                if (not double_caster.load(src, convert)) {                                        \
                    return false;                                                                  \
                }                                                                                  \
                value.value = std::move(static_cast<double &>(double_caster));                     \
            }                                                                                      \
                                                                                                   \
            return true;                                                                           \
        }                                                                                          \
                                                                                                   \
        /* Conversion C++ -> Python is double -> Quantity */                                       \
        static py::handle cast(cpp_type src, py::return_value_policy, py::handle) {                \
            py::object kq = py::module_::import("Karana.Math.Kquantities");                        \
            py::object res =                                                                       \
                py::float_(src.value) * kq.attr("getDefaultUnits")(kq.attr(#quantity_name));       \
            return res.release();                                                                  \
        }                                                                                          \
    };

    FLOAT_QUANTITY(Karana::Math::Length, Karana.Math.Ktyping.Length, length);
    FLOAT_QUANTITY(Karana::Math::Angle, Karana.Math.Ktyping.Angle, angle);
    FLOAT_QUANTITY(Karana::Math::Mass, Karana.Math.Ktyping.Mass, mass);
    FLOAT_QUANTITY(Karana::Math::GravitationalParameter,
                   Karana.Math.Ktyping.GravitationalParameter,
                   gravitational_parameter);

#define VEC3_QUANTITY(cpp_type, py_type, quantity_name)                                            \
    template <> struct type_caster<cpp_type> {                                                     \
      public:                                                                                      \
        PYBIND11_TYPE_CASTER(cpp_type,                                                             \
                             pybind11::detail::io_name("Union[numpy.typing.ArrayLike|" #py_type    \
                                                       "]",                                        \
                                                       #py_type));                                 \
                                                                                                   \
        /* Conversion from Python -> C++ */                                                        \
        bool load(py::handle src, bool convert) {                                                  \
                                                                                                   \
            if (!src)                                                                              \
                return false;                                                                      \
                                                                                                   \
            /* Imports for conversion */                                                           \
            py::object py_quantity = py::module_::import("pint").attr("Quantity");                 \
            py::object py_length =                                                                 \
                py::module_::import("Karana.Math.Kquantities").attr(#quantity_name);               \
                                                                                                   \
            py::object obj = py::reinterpret_borrow<py::object>(src);                              \
                                                                                                   \
            if (py::isinstance(obj, py_quantity)) {                                                \
                                                                                                   \
                if (py::cast<bool>(obj.attr("check")(py_length))) {                                \
                                                                                                   \
                    /* It's a quantity of the correct type, convert to Vec3. */                    \
                    value.value = py::cast<km::Vec3>(obj.attr("to_base_units")().attr("m"));       \
                } else {                                                                           \
                    /* This is the wrong type, do not convert */                                   \
                    return false;                                                                  \
                }                                                                                  \
            } else { /* Assume it's a regular Vec3 */                                              \
                detail::type_caster<km::Vec3> vec_caster;                                          \
                if (!vec_caster.load(src, convert)) {                                              \
                    return false;                                                                  \
                }                                                                                  \
                value.value = std::move(static_cast<km::Vec3 &>(vec_caster));                      \
            }                                                                                      \
            return true;                                                                           \
        }                                                                                          \
                                                                                                   \
        /* Conversion C++ -> Python is Vec3 -> Quantity */                                         \
        static py::handle cast(cpp_type src, py::return_value_policy, py::handle) {                \
                                                                                                   \
            py::object kq = py::module_::import("Karana.Math.Kquantities");                        \
                                                                                                   \
            py::object res =                                                                       \
                py::cast(src.value) * kq.attr("getDefaultUnits")(kq.attr(#quantity_name));         \
            return res.release();                                                                  \
        }                                                                                          \
    };

    VEC3_QUANTITY(Karana::Math::Length3, Karana.Math.Ktyping.Length3, length);
    VEC3_QUANTITY(Karana::Math::Angle3, Karana.Math.Ktyping.Angle3, angle);
    VEC3_QUANTITY(Karana::Math::Velocity3, Karana.Math.Ktyping.Velocity3, velocity);
    VEC3_QUANTITY(Karana::Math::Acceleration3, Karana.Math.Ktyping.Acceleration3, acceleration);
    VEC3_QUANTITY(Karana::Math::AngularVelocity3,
                  Karana.Math.Ktyping.AngularVelocity3,
                  angular_velocity);
    VEC3_QUANTITY(Karana::Math::AngularAcceleration3,
                  Karana.Math.Ktyping.AngularAcceleration3,
                  angular_acceleration);
    VEC3_QUANTITY(Karana::Math::Force3, Karana.Math.Ktyping.Force3, force);
    VEC3_QUANTITY(Karana::Math::Torque3, Karana.Math.Ktyping.Torque3, torque);
    VEC3_QUANTITY(Karana::Math::Momentum3, Karana.Math.Ktyping.Momentum3, momentum);
    VEC3_QUANTITY(Karana::Math::AngularMomentum3,
                  Karana.Math.Ktyping.AngularMomentum3,
                  angular_momentum);

#define MAT33_QUANTITY(cpp_type, py_type, quantity_name)                                           \
    template <> struct type_caster<cpp_type> {                                                     \
      public:                                                                                      \
        PYBIND11_TYPE_CASTER(cpp_type,                                                             \
                             pybind11::detail::io_name("Union[numpy.typing.ArrayLike|" #py_type    \
                                                       "]",                                        \
                                                       #py_type));                                 \
                                                                                                   \
        /* Conversion from Python -> C++ */                                                        \
        bool load(py::handle src, bool convert) {                                                  \
                                                                                                   \
            if (!src)                                                                              \
                return false;                                                                      \
                                                                                                   \
            /* Imports for conversion */                                                           \
            py::object py_quantity = py::module_::import("pint").attr("Quantity");                 \
            py::object py_length =                                                                 \
                py::module_::import("Karana.Math.Kquantities").attr(#quantity_name);               \
                                                                                                   \
            py::object obj = py::reinterpret_borrow<py::object>(src);                              \
                                                                                                   \
            if (py::isinstance(obj, py_quantity)) {                                                \
                                                                                                   \
                if (py::cast<bool>(obj.attr("check")(py_length))) {                                \
                                                                                                   \
                    /* It's a quantity of the correct type, convert to Mat33. */                   \
                    value.value = py::cast<km::Mat33>(obj.attr("to_base_units")().attr("m"));      \
                } else {                                                                           \
                    /* This is the wrong type, do not convert */                                   \
                    return false;                                                                  \
                }                                                                                  \
            } else { /* Assume it's a regular Mat33 */                                             \
                detail::type_caster<km::Mat33> mat_caster;                                         \
                if (!mat_caster.load(src, convert)) {                                              \
                    return false;                                                                  \
                }                                                                                  \
                value.value = std::move(static_cast<km::Mat33 &>(mat_caster));                     \
            }                                                                                      \
            return true;                                                                           \
        }                                                                                          \
                                                                                                   \
        /* Conversion C++ -> Python is Mat33 -> Quantity */                                        \
        static py::handle cast(cpp_type src, py::return_value_policy, py::handle) {                \
                                                                                                   \
            py::object kq = py::module_::import("Karana.Math.Kquantities");                        \
                                                                                                   \
            py::object res =                                                                       \
                py::cast(src.value) * kq.attr("getDefaultUnits")(kq.attr(#quantity_name));         \
            return res.release();                                                                  \
        }                                                                                          \
    };

    MAT33_QUANTITY(Karana::Math::Inertia, Karana.Math.Ktyping.Inertia, inertia);

} // namespace pybind11::detail