Source code for Karana.KUtils.DataStruct

# Copyright (c) 2024-2025 Karana Dynamics Pty Ltd. All rights reserved.
#
# NOTICE TO USER:
#
# This source code and/or documentation (the "Licensed Materials") is
# the confidential and proprietary information of Karana Dynamics Inc.
# Use of these Licensed Materials is governed by the terms and conditions
# of a separate software license agreement between Karana Dynamics and the
# Licensee ("License Agreement"). Unless expressly permitted under that
# agreement, any reproduction, modification, distribution, or disclosure
# of the Licensed Materials, in whole or in part, to any third party
# without the prior written consent of Karana Dynamics is strictly prohibited.
#
# THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND.
# KARANA DYNAMICS DISCLAIMS ALL WARRANTIES, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-INFRINGEMENT, AND
# FITNESS FOR A PARTICULAR PURPOSE.
#
# IN NO EVENT SHALL KARANA DYNAMICS BE LIABLE FOR ANY DAMAGES WHATSOEVER,
# INCLUDING BUT NOT LIMITED TO LOSS OF PROFITS, DATA, OR USE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGES, WHETHER IN CONTRACT, TORT,
# OR OTHERWISE ARISING OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS.
#
# U.S. Government End Users: The Licensed Materials are a "commercial item"
# as defined at 48 C.F.R. 2.101, and are provided to the U.S. Government
# only as a commercial end item under the terms of this license.
#
# Any use of the Licensed Materials in individual or commercial software must
# include, in the user documentation and internal source code comments,
# this Notice, Disclaimer, and U.S. Government Use Provision.

"""Classes and functions related to DataStruct.

DataStruct is a wrapper around `pydantic.BaseModel`. It contains extra functionality to:
    * Easily save/load `DataStruct`s from various file types.
    * Compare `DataStruct`s with one another either identically, using `__eq__`, or
      approximately, using `isApprox`.
    * Populate a Kclick CLI application's options from a `DataStruct`.
    * Create an instance of a `DataStruct` from a Kclick CLI application's options.
    * Generate an asciidoc with the data of a `DataStruct`.

The saving/loading and comparison can be done recursively, meaning they accept `DataStruct`s
whose fields are other `DataStruct`s. In addition, they support nested Python types, such as
a dict of lists of numpy arrays.
"""

import h5py
import numpy as np
from numpy.typing import NDArray
import quantities as pq
from importlib import import_module
from inspect import getmodule
from pydantic import (
    BaseModel,
    ConfigDict,
    model_serializer,
    PrivateAttr,
    model_validator,
    ModelWrapValidatorHandler,
    Field,
)
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from numpydoc.docscrape import NumpyDocString
from typing import (
    Any,
    Optional,
    ClassVar,
    Self,
    TypeVar,
    cast,
    overload,
    IO,
    Literal,
    Callable,
    Generic,
)
from pathlib import Path
import click
from copy import deepcopy
from ruamel.yaml import YAML, RoundTripRepresenter
from json import dumps, JSONEncoder, load
from Karana.Math import MATH_EPSILON
from Karana.Core import warn, Base, CppWeakRef


def _h5RecurseSave(pg: h5py.Group, name: Optional[str], data: Any):
    """Recursively save some data to HDF5.

    Parameters
    ----------
    pg : h5py.Group
        The parent group to attach this data to.
    name : Optional[str]
        The name to call the data. This will be the name of
        the child group or dataset, depending on the data type.
        If it is None, the data is written directly into the parent group.
    data : Any
        The data to save.
    """
    if isinstance(data, dict):
        if name is not None:
            g = pg.create_group(name)
        else:
            g = pg
        g.attrs["PY_TYPE"] = "dict"
        for k, v in data.items():
            _h5RecurseSave(g, k, v)
    elif isinstance(data, list):
        if name is not None:
            g = pg.create_group(name)
        else:
            g = pg
        g.attrs["PY_TYPE"] = "list"
        g.attrs["NUM_ELEMENTS"] = len(data)
        for k, d in enumerate(data):
            name = f"element_{k}"
            _h5RecurseSave(g, name, d)
    elif isinstance(data, tuple):
        if name is not None:
            g = pg.create_group(name)
        else:
            g = pg
        g.attrs["PY_TYPE"] = "tuple"
        g.attrs["NUM_ELEMENTS"] = len(data)
        for k, d in enumerate(data):
            name = f"element_{k}"
            _h5RecurseSave(g, name, d)
    elif isinstance(data, set):
        if name is not None:
            g = pg.create_group(name)
        else:
            g = pg
        g.attrs["PY_TYPE"] = "set"
        g.attrs["NUM_ELEMENTS"] = len(data)
        for k, d in enumerate(data):
            name = f"element_{k}"
            _h5RecurseSave(g, name, d)
    elif isinstance(data, pq.Quantity):
        dataset = pg.create_dataset(name, data=data)
        dataset.attrs["UNIT"] = str(data.units).split()[1]
        dataset.attrs["PY_TYPE"] = "Quantity"
    elif isinstance(data, np.ndarray):
        dataset = pg.create_dataset(name, data=data)
        dataset.attrs["PY_TYPE"] = "np.array"
    elif hasattr(data, "__getstate__") and hasattr(data, "__setstate__"):
        # If this is a pickleable object, then we use the information there.
        if name is not None:
            g = pg.create_group(name)
        else:
            g = pg

        mod = getmodule(data.__class__)
        if mod is None:
            raise ValueError("Issues getting fully qualified string for object.")
        name = cast(str, data.__class__.__name__)  # None is not possible, as this is a class
        py_type = mod.__name__ + "." + name
        if hasattr(data, "name"):
            if callable(data.name):
                name = cast(str, data.name())
            else:
                name = data.name
        else:
            name = py_type
        g.attrs["PY_TYPE"] = py_type
        g.attrs["DATA_NAME"] = name
        _h5RecurseSave(g, name, data.__getstate__())
    elif isinstance(data, str):
        # We give this special attention, as strings will reload as bytes, so we need to
        # convert them on load.
        dataset = pg.create_dataset(name, data=data)
        dataset.attrs["PY_TYPE"] = "str"
    elif data is None:
        dataset = pg.create_dataset(name, data="PY_NONE")
        dataset.attrs["PY_TYPE"] = "None"
    else:
        # This is just a normal python object. Save it to a dataset.
        dataset = pg.create_dataset(name, data=data)


def _h5RecurseLoad(g: h5py.Group | h5py.Dataset):
    """Recursively load some data from HDF5.

    Parameters
    ----------
    g : h5py.Group | h5py.Dataset
        The group or dataset to load.

    Returns
    -------
    Any
        The loaded Python object.
    """
    py_type = g.attrs.get("PY_TYPE", None)
    if isinstance(g, h5py.Group):
        if py_type == "dict":
            dark = {}
            for k, v in g.items():
                dark[k] = _h5RecurseLoad(v)
            return dark
        elif py_type == "list":
            return [_h5RecurseLoad(x) for x in g.values()]
        elif py_type == "tuple":
            return tuple(_h5RecurseLoad(x) for x in g.values())
        elif py_type == "set":
            return {_h5RecurseLoad(x) for x in g.values()}
        else:
            if py_type is None:
                raise ValueError("Cannot get py_type for group")
            dark = py_type.split(".")
            class_str = dark[-1]
            module = ".".join(dark[:-1])
            mod = import_module(module)
            klass = getattr(mod, class_str)
            dark = klass.__new__(klass)
            name = g.attrs["DATA_NAME"]
            dark.__setstate__(_h5RecurseLoad(cast(h5py.Group | h5py.Dataset, g[name])))
            return dark
    else:
        # This is a dataset
        if py_type == "str":
            return g[()].decode("utf-8")
        elif py_type == "None":
            return None
        elif py_type == "Quantity":
            return pq.Quantity(g[()], g.attrs["UNIT"])
        else:
            return g[()]


class _JsonRoundTrip:
    """Helper to serialize and deserialize custom objects.

    Register classes or types with this and then use the CustomEncoder and objectHook
    when using json.dumps and json.loads respectively to serialize and deserialize
    custom objects.
    """

    class CustomEncoder(JSONEncoder):
        """Serialize custom types for json.

        Register types with the extra_types and extra_types_tags variables.
        extra_types pairs a class type with a method that represents that class
        as a json serializable dictionary. Note, the dictionary can have other custom types
        and those will be handled recursively.
        extra_types_tags maps each type to a unique tag.
        """

        _rt: "_JsonRoundTrip"
        extra_types: dict[Any, Callable[[Any], dict[str, Any]]] = {}
        extra_types_tags: dict[Any, str] = {}

        def default(self, o: Any) -> dict[str, Any] | None:
            """Serialize json objects.

            Overrides the default method and handles custom objects first. It will try to register
            objects it does not recognize if they have the to_json method defined. If the object is
            not a custom object, then just use the parent method.

            Parameters
            ----------
            o : Any
                The object to serialize.

            Returns
            -------
            dict[str,Any] | None
                The serialized object or None.
            """
            # Check custom objects first, using an exact type match against the registered types.
            v = self.extra_types.get(type(o), None)
            if v is not None:
                return v(o) | {"__type__": self.extra_types_tags[type(o)]}

            # This is a type we can register, so do that:
            if hasattr(o, "to_json"):
                cls = type(o)
                if not "json_tag" in cls.__dict__:
                    cls.json_tag = "!" + cls.__module__ + "." + cls.__qualname__
                self._rt.registerClass(cls)
                return cls.to_json(o) | {"__type__": self.extra_types_tags[type(o)]}

            # If this is not a custom object we know about, nor one we can register, call the
            # parent method.
            return super().default(o)

    def __init__(self):
        """Create _JsonRoundTrip instance."""
        self.extra_types: dict[str, Callable[[dict[str, Any]], Any]] = {}
        self.CustomEncoder._rt = self

    def registerType[
        T
    ](
        self,
        klass: type[T],
        tag: str,
        serialize: Callable[[T], dict[str, Any]],
        deserialize: Callable[[dict[str, Any]], T],
    ) -> None:
        """Register a type T with the custom serializer and deserializer.

        Parameters
        ----------
        klass : type[T]
            The type to register.
        tag : str
            The tag associated with the given type.
        serialize : Callable[[T], dict[str, Any]]
            The serialization function.
        deserialize : Callable[[dict[str,Any]], T]
            The deserialization function.
        """
        self.CustomEncoder.extra_types[klass] = serialize
        self.CustomEncoder.extra_types_tags[klass] = tag
        self.extra_types[tag] = deserialize

    def registerClass(self, klass: Any):
        """Register a class with the custom serializer.

        The class must define a json_tag, which is a string,
        a to_json method, which is the serializer, and a from_json method,
        which is the deserializer.

        Parameters
        ----------
        klass : Any
            The class to register.
        """
        for atr in ["json_tag", "to_json", "from_json"]:
            if not hasattr(klass, atr):
                raise ValueError(f"Class {klass} is missing {atr}.")
        self.registerType(klass, klass.json_tag, klass.to_json, klass.from_json)

    def objectHook(self, d: dict[str, Any]) -> Any:
        """Deserialize function for json.loads.

        This uses the registered types during deserialization. If it encounters a type it does
        not recognize, it will try to dynamically register it.

        Parameters
        ----------
        d : dict[str,Any]
            The serialized object.

        Returns
        -------
        Any
            The deserialized object.
        """
        if "__type__" in d:
            t = d["__type__"]

            # Try looking into registered classes first
            for k, v in reversed(list(self.extra_types.items())):
                if t == k:
                    return v(d)

            # Try loading this type dynamically assuming __type__ is the class
            # If that fails, just try the default
            try:
                # Try importing the __type__ as a class and creating it
                dark = t[1:].split(".")
                class_str = dark[-1]
                module = ".".join(dark[:-1])
                mod = import_module(module)
                cls = getattr(mod, class_str)
                # We check cls.__dict__ here rather than hasattr, as we only want to check this version of the class.
                # We don't want to look into base classes. Otherwise, we will get potentially the wrong tag.
                if not "json_tag" in cls.__dict__:
                    cls.json_tag = t
                self.registerClass(cls)
                return cls.from_json(d)
            except:
                pass

        return d
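
# Sketch of registering a custom type with _JsonRoundTrip (``Vec2`` is a hypothetical
# class used only for illustration):
#
#     class Vec2:
#         def __init__(self, x: float, y: float):
#             self.x, self.y = x, y
#
#     rt = _JsonRoundTrip()
#     rt.registerType(Vec2, "!Vec2", lambda v: {"x": v.x, "y": v.y},
#                     lambda d: Vec2(d["x"], d["y"]))
#     s = dumps({"v": Vec2(1.0, 2.0)}, cls=rt.CustomEncoder)
#     v = loads(s, object_hook=rt.objectHook)["v"]   # assumes ``from json import loads``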


class DataStruct(BaseModel):
    """Wrapper around `pydantic.BaseModel` that adds functionality useful for modeling and simulation.

    This class adds functionality to the `pydantic.BaseModel`, including:
        * Easily save/load `DataStruct`s from various file types.
        * Compare `DataStruct`s with one another either identically, using `__eq__`, or
          approximately, using `isApprox`.
        * Generate an asciidoc with the data of a `DataStruct`.

    The saving/loading and comparison can be done recursively, meaning they accept `DataStruct`s
    whose fields are other `DataStruct`s. In addition, they support nested Python types, such as
    a dict of lists of numpy arrays.

    Parameters
    ----------
    version : tuple[int, int]
        Holds the version of this DataStruct. Users should override the current version using
        the _version_default class variable. DataStructs that use this should also add a field
        validator or model validator to handle version mismatches.
    """

    # Define a ClassVar with the default for version.
    # This allows users to easily override it.
    _version_default: ClassVar[tuple[int, int]] = (0, 0)

    # Make this strict=False so pydantic will coerce a list to a tuple.
    version: tuple[int, int] = Field(default=_version_default, strict=False)

    # Enable strict mode to avoid type coercion.
    model_config = ConfigDict(strict=True, arbitrary_types_allowed=True, validate_assignment=True)

    # Holds the DSLinker if one exists for this class.
    _ds_linker: ClassVar[Optional["DSLinker"]] = None

    def _migrate_experimental(self, deep=True) -> Self:
        """Create a copy, filling in missing fields with default values.

        This may be useful in cases where the DataStruct was instantiated without going through
        the constructor, such as when unpickling.

        Parameters
        ----------
        deep : bool
            If True, does a deep copy.

        Returns
        -------
        Self
            The copy.
        """
        if deep:
            return self.model_validate(self.model_dump(), strict=False)
        return self.model_validate(self.__dict__, strict=False)

    @classmethod
    def generateAsciidoc(cls) -> str:
        """Generate asciidoc for the DataStruct."""
        # Also, get the docstrings. We parse this information into a dictionary where the keys are
        # the parameter names and the values are the description.
        if cls.__doc__ is not None:
            doc_string_params = {
                name: "\n".join(desc)
                for name, _, desc in NumpyDocString(cls.__doc__).get("Parameters", ())
            }
        else:
            doc_string_params = {}

        # Write a title and table header.
        title = f"=== {cls.__name__} ===\n\n"
        fields_table = "|===\n|Field |Type |Required |Description |Default\n"

        # Loop through each parameter and add its info to the table.
        for name, field_info in cls.model_fields.items():
            # Get the field annotation.
            field_type = field_info.annotation

            # For common base types, convert to the name.
            if field_type is int:
                field_type = "int"
            elif field_type is str:
                field_type = "str"
            elif field_type is float:
                field_type = "float"

            required = "Yes" if field_info.is_required() else "No"

            # For the description, if the parameter has a description use that.
            # If it doesn't, use the docstring information.
            # If it has neither, then set its description to an empty string.
            description = field_info.description
            if description is None or description == "":
                description = doc_string_params.get(name, "")

            default = field_info.default
            if default == PydanticUndefined:
                default = ""

            fields_table += f"|{name} |{field_type} |{required} |{description} |{default}\n"

        fields_table += "|===\n"
        return title + fields_table

    @classmethod
    def cli(cls, exclude: list[str] = [], extra_info: dict[str, Any] = {}):
        """Add this DataStruct's fields to a cli.

        Parameters
        ----------
        exclude : list[str]
            List of fields to exclude from the cli.
        extra_info : dict[str, Any]
            Extra information that is used to generate the cli. Examples include:
                * name - Name of the field.
                * help - Description for the field.
                * type - Type for the field.
                * default - Default for the field.
        """
        if cls._ds_linker is None:
            cls._ds_linker = DSLinker(cls, exclude, extra_info)
        else:
            if not (cls._ds_linker.exclude == exclude and cls._ds_linker.extra_info == extra_info):
                raise ValueError(
                    "DSLinker already exists with a different exclude and extra info. "
                    "Need to create DSLinkers manually."
                )
        return cls._ds_linker.toCli()

    @classmethod
    def fromLinker(
        cls,
        vals_orig: dict[str, Any],
        excluded_vals: dict[str, Any] = {},
        ds_linker: Optional["DSLinker"] = None,
    ) -> Self:
        """Create an instance of the DataStruct from a CLI dictionary.

        Parameters
        ----------
        vals_orig : dict[str, Any]
            CLI dictionary to use to build the DataStruct.
        excluded_vals : dict[str, Any]
            Include any vals that were excluded from the DSLinker.
        ds_linker : Optional[DSLinker]
            DSLinker to use to build the DataStruct. This is optional, and is only necessary
            if there are multiple DSLinkers for a DataStruct.
        """
        if ds_linker is None:
            ds_linker = cls._ds_linker
        if ds_linker is None:
            raise ValueError(f"DSLinker does not exist for {cls}.")
        if not ds_linker.name_link:
            raise ValueError(f"DSLinker has not been run for {cls}.")

        vals = deepcopy(vals_orig) | excluded_vals
        keys = list(vals.keys())
        for k in keys:
            if k in ds_linker.exclude:
                continue
            k_new = ds_linker.name_link["--" + k.replace("_", "-")]
            vals[k_new] = vals.pop(k)
        return cls(**vals)
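
    # Sketch of wiring a DataStruct into a click command via cli()/fromLinker() (command and
    # field names are illustrative; a Kclick CLI is assumed to follow the standard click
    # decorator pattern):
    #
    #     @click.command()
    #     @VehicleDS.cli(exclude=["version"])
    #     def run(**cli_vals):
    #         ds = VehicleDS.fromLinker(cli_vals)
    #         ...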

    @classmethod
    def _ruamelYaml(cls) -> YAML:
        """Get the YAML class for this DataStruct.

        This contains a dynamic loader and dumper that will try to automatically register
        classes that have a to_yaml or from_yaml method defined.

        Returns
        -------
        YAML
            The YAML class for this DataStruct.
        """
        from ruamel.yaml import CommentedMap

        yaml = YAML()
        # Set flow style to None so that lists get represented as []
        yaml.default_flow_style = None

        # Add LazyRepresenter. This will dynamically register classes for dumping if they have the
        # to_yaml method defined. A yaml_tag will be given if the class does not already define it.
        class LazyRepresenter(RoundTripRepresenter):
            def represent_data(self, data):
                cls = type(data)
                if cls not in self.yaml_representers:
                    if hasattr(cls, "to_yaml"):
                        # We check cls.__dict__ here rather than hasattr, as we only want to check
                        # this version of the class. We don't want to look into base classes.
                        # Otherwise, we will potentially get the wrong tag.
                        if "yaml_tag" not in cls.__dict__:
                            cls.yaml_tag = "!" + cls.__module__ + "." + cls.__qualname__
                        yaml.register_class(cls)
                return super().represent_data(data)

        yaml.Representer = LazyRepresenter

        # Add a lazy constructor. This will try to dynamically register classes it does not
        # recognize.
        def dynamic_constructor(loader, node):
            # This handles any tag we haven't seen before.
            tag_suffix = node.tag[1:]
            try:
                dark = tag_suffix.split(".")
                class_str = dark[-1]
                module = ".".join(dark[:-1])
                mod = import_module(module)
                cls = getattr(mod, class_str)
                # We check cls.__dict__ here rather than hasattr, as we only want to check this
                # version of the class. We don't want to look into base classes. Otherwise, we
                # will potentially get the wrong tag.
                if "yaml_tag" not in cls.__dict__:
                    cls.yaml_tag = node.tag
                yaml.register_class(cls)
                return cls.from_yaml(loader, node)
            except:
                raise ValueError(f"Cannot dynamically load yaml constructor for {tag_suffix}")

        yaml.Constructor.add_constructor(None, dynamic_constructor)

        # Representer for numpy arrays
        NumpyArrayTag = "!numpy_array"

        # Custom representer for numpy arrays
        def numpy_array_representer(dumper, data):
            # Convert the NumPy array to a list
            return dumper.represent_mapping(
                NumpyArrayTag,
                {"dtype": str(data.dtype), "shape": data.shape, "data": data.tolist()},
            )

        # Custom constructor for numpy arrays
        def numpy_array_constructor(loader, node):
            # Get the mapping from the YAML node
            values = CommentedMap()
            loader.construct_mapping(node, values, deep=True)
            # Recreate the NumPy array
            return np.array(values["data"], dtype=values["dtype"]).reshape(values["shape"])

        # Register the representer and constructor with ruamel.yaml
        yaml.register_class(
            type("NumpyArrayWrapper", (), {})
        )  # Avoid default serialization for NumPy array subclasses
        yaml.representer.add_representer(np.ndarray, numpy_array_representer)
        yaml.constructor.add_constructor(NumpyArrayTag, numpy_array_constructor)

        # Representer for quantities
        QuantityTag = "!quantity"

        # Custom representer for pq.Quantity
        def quantityRepresenter(dumper, data):
            return dumper.represent_mapping(
                QuantityTag,
                {
                    "units": str(data.units).split()[1],
                    "dtype": str(data.dtype),
                    "shape": data.shape,
                    "data": data.magnitude.tolist(),
                },
            )

        # Custom constructor for pq.Quantity
        def quantityConstructor(loader, node):
            # Get the mapping from the YAML node
            values = CommentedMap()
            loader.construct_mapping(node, values, deep=True)
            # Recreate the Quantity
            return pq.Quantity(
                np.array(values["data"], dtype=values["dtype"]).reshape(values["shape"]),
                values["units"],
            )

        # Register the representer and constructor with ruamel.yaml
        yaml.register_class(
            type("QuantityWrapper", (), {})
        )  # Avoid default serialization for Quantity subclasses
        yaml.representer.add_representer(pq.Quantity, quantityRepresenter)
        yaml.constructor.add_constructor(QuantityTag, quantityConstructor)

        return yaml

    @classmethod
    def _jsonRoundTrip(cls) -> _JsonRoundTrip:
        """Get the _JsonRoundTrip class for this DataStruct.

        This contains a dynamic loader and dumper that will try to automatically register
        classes that have a to_json or from_json method defined.

        Returns
        -------
        _JsonRoundTrip
            The _JsonRoundTrip class for this DataStruct.
        """
        jrt = _JsonRoundTrip()

        # Representer for numpy arrays
        numpy_array_tag = "!numpy_array"

        # Custom representer for numpy arrays
        def numpyArrayDump(o: NDArray):
            return {"dtype": str(o.dtype), "shape": o.shape, "data": o.tolist()}

        # Custom constructor for numpy arrays
        def numpyArrayLoad(d: dict[str, Any]):
            return np.array(d["data"], dtype=d["dtype"]).reshape(d["shape"])

        jrt.registerType(np.ndarray, numpy_array_tag, numpyArrayDump, numpyArrayLoad)

        # Representer for quantities
        quantity_tag = "!quantity"

        # Custom representer for pq.Quantity
        def quantityDump(o: pq.Quantity) -> dict[str, Any]:
            return {
                "units": str(o.units).split()[1],
                "dtype": str(o.dtype),
                "shape": o.shape,
                "data": o.magnitude.tolist(),
            }

        # Custom constructor for pq.Quantity
        def quantityLoad(d: dict[str, Any]):
            # Recreate the Quantity
            return pq.Quantity(
                np.array(d["data"], dtype=d["dtype"]).reshape(d["shape"]),
                d["units"],
            )

        jrt.registerType(pq.Quantity, quantity_tag, quantityDump, quantityLoad)

        return jrt

    @overload
    def toFile(
        self,
        file: Path | str | IO[bytes],
        suffix: Optional[
            Literal[".json", ".yaml", ".yml", ".h5", ".hdf5", ".pickle", ".pck", ".pcl"]
        ] = None,
    ) -> None:
        """Write the DataStruct to a file.

        Parameters
        ----------
        file : Path | str | IO[bytes]
            Specify the file to write to. If the optional suffix argument is not specified, the
            suffix of the filename is used to determine the type of file to write, e.g., ".yaml"
            for a YAML file or ".h5" for an HDF5 file. In the case of IO[bytes], if the IO object
            has a name attr, that will be used to look up the suffix.
        suffix : Optional[Literal[".json", ".yaml", ".yml", ".h5", ".hdf5", ".pickle", ".pck", ".pcl"]]
            This optional keyword argument is used to set the file type. If it is not specified,
            then the file type is inferred from the name. If it can't be inferred from the name,
            then YAML is used.
        """
        ...  # pragma: no cover. Will not be run by code coverage.

    @overload
    def toFile(self, g: h5py.Group) -> None:
        """Write the DataStruct to an H5 group.

        Parameters
        ----------
        g : h5py.Group
            The group to write to.
        """
        ...  # pragma: no cover. Will not be run by code coverage.

    def toFile(self, *args, **kwargs):
        """Write the DataStruct to a file or H5 group.

        See overloads for details.
        """
        if isinstance(group := args[0], h5py.Group):
            _h5RecurseSave(group, None, self)
        else:
            # Figure out the suffix
            file = args[0]
            suffix = kwargs.get("suffix", None)
            if suffix is None:
                if isinstance(file, Path):
                    suffix = file.suffix
                elif isinstance(file, str):
                    suffix = "." + file.split(".")[-1].strip()
                elif hasattr(file, "name"):
                    suffix = "." + file.name.split(".")[-1].strip()
                else:
                    warn(f"Cannot determine suffix from {file}. Using yaml.")
                    suffix = ".yaml"

            def _write(f, suffix):
                if suffix == ".json":
                    rt = self._jsonRoundTrip()
                    f.write(dumps(self.model_dump(), cls=rt.CustomEncoder))
                elif suffix in [".pickle", ".pck", ".pcl"]:
                    import pickle

                    pickle.dump(self, f)
                elif suffix in [".hdf5", ".h5"]:
                    _h5RecurseSave(f, None, self)
                else:
                    yaml = self._ruamelYaml()
                    yaml.dump(self.model_dump(), f)

            if isinstance(file, (Path, str)):
                # If this is a string or path, we need to open the file to write it.
                # Do so in a context to ensure it is closed.
                if suffix in [".hdf5", ".h5"]:
                    with h5py.File(file, "w") as f:
                        _write(f, suffix)
                elif suffix in [".pickle", ".pck", ".pcl"]:
                    with open(file, "wb") as f:
                        _write(f, suffix)
                else:
                    with open(file, "w") as f:
                        _write(f, suffix)
            else:
                # This is an already open file. Just write to it.
                _write(file, suffix)

    @classmethod
    @overload
    def fromFile(
        cls,
        file: Path | str | IO[bytes],
        suffix: Optional[
            Literal[".json", ".yaml", ".yml", ".h5", ".hdf5", ".pickle", ".pck", ".pcl"]
        ] = None,
    ) -> Self:
        """Create an instance of this DataStruct from a file.

        Parameters
        ----------
        file : Path | str | IO[bytes]
            Specify the file to read from. If the optional suffix argument is not specified, the
            suffix of the filename is used to determine the type of file to read, e.g., ".yaml"
            for a YAML file or ".h5" for an HDF5 file. In the case of IO[bytes], if the IO object
            has a name attr, that will be used to look up the suffix.
        suffix : Optional[Literal[".json", ".yaml", ".yml", ".h5", ".hdf5", ".pickle", ".pck", ".pcl"]]
            This optional keyword argument is used to set the file type. If it is not specified,
            then the file type is inferred from the name. If it can't be inferred from the name,
            then YAML is used.

        Returns
        -------
        Self
            Instance of this DataStruct.
        """
        ...  # pragma: no cover. Will not be run by code coverage.

    @classmethod
    @overload
    def fromFile(cls, g: h5py.Group) -> Self:
        """Create an instance of this DataStruct from an h5py.Group.

        Parameters
        ----------
        g : h5py.Group
            The group to read from.

        Returns
        -------
        Self
            Instance of this DataStruct.
        """
        ...  # pragma: no cover. Will not be run by code coverage.

    @classmethod
    def fromFile(cls, *args, **kwargs) -> Self:
        """Create an instance of this DataStruct from a file.

        See overloads for details.
        """
        if isinstance(group := args[0], h5py.Group):
            return cast(Self, _h5RecurseLoad(group))
        else:
            # Figure out the suffix
            file = args[0]
            suffix = kwargs.get("suffix", None)
            if suffix is None:
                if isinstance(file, Path):
                    suffix = file.suffix
                elif isinstance(file, str):
                    suffix = "." + file.split(".")[-1].strip()
                elif hasattr(file, "name"):
                    suffix = "." + file.name.split(".")[-1].strip()
                else:
                    warn(f"Cannot determine suffix from {file}. Using yaml.")
                    suffix = ".yaml"

            def _read(f, suffix):
                if suffix == ".json":
                    rt = cls._jsonRoundTrip()
                    return load(f, object_hook=rt.objectHook)
                elif suffix in [".pickle", ".pck", ".pcl"]:
                    import pickle

                    return pickle.load(f)
                else:
                    yaml = cls._ruamelYaml()
                    return yaml.load(f)

            if isinstance(file, (Path, str)):
                # If this is a string or path, we need to open the file to read it.
                # Do so in a context to ensure it is closed.
                if suffix in [".hdf5", ".h5"]:
                    with h5py.File(file, "r") as f:
                        ret = cast(Self, _h5RecurseLoad(f))
                        return ret
                elif suffix in [".pickle", ".pck", ".pcl"]:
                    with open(file, "rb") as f:
                        return _read(f, suffix)
                else:
                    with open(file, "r") as f:
                        from_dict = _read(f, suffix)
                        from_dict = cast(dict[str, Any], from_dict)
                        return cls(**from_dict)
            else:
                # Otherwise, this is an IO object that is already open. Just read from it.
                # No need to handle H5 files here, because they are also H5 Groups, so they
                # are handled above.
                if suffix in [".pickle", ".pck", ".pcl"]:
                    return _read(file, suffix)
                else:
                    from_dict = _read(file, suffix)
                    from_dict = cast(dict[str, Any], from_dict)
                    return cls(**from_dict)

    def isApprox(self, other: Self, prec: float = MATH_EPSILON) -> bool:
        """Check if this DataStruct is approximately equal to another DataStruct of the same type.

        This recursively moves through the public fields only. Note that Pydantic's __eq__
        checks private and public fields.

        Recursively here means that if the field is an iterator, another DataStruct, an iterator
        of DataStructs, etc., we will go into the nested structure calling isApprox where
        appropriate on the items in the iterator, fields in the DataStruct, etc.

        If the field (or iterated value, etc.) does not have an isApprox method, then we fall
        back to using __eq__. If all calls to isApprox (or __eq__) return True, then this
        returns True; otherwise, this returns False.

        Parameters
        ----------
        other : Self
            A DataStruct of the same type.
        prec : float
            The precision to use. Karana.Math.MATH_EPSILON is the default.

        Returns
        -------
        bool
            True if the two DataStructs are approximately equal. False otherwise.
        """

        def _isApproxRecurse(val1, val2, prec: float):
            # Call isApprox if it exists on the item. Otherwise, call the == method.
            # Loop through different iterator types recursively as appropriate.
            if hasattr(val1, "isApprox"):
                return val1.isApprox(val2, prec=prec)
            elif isinstance(val1, (tuple, list)):
                return all(_isApproxRecurse(x, y, prec) for x, y in zip(val1, val2))
            elif isinstance(val1, dict):
                if len(val1) != len(val2):
                    return False
                else:
                    return all(_isApproxRecurse(v, val2[k], prec) for k, v in val1.items())
            elif isinstance(val1, pq.Quantity):
                # If these are quantities, then just look at the raw values. Leaving the units
                # in would result in errors, as this compares with things like
                # atol + rtol * abs(y), which would change the units to be units**2 for
                # rtol * abs(y).
                return np.allclose(val1.magnitude, val2.magnitude, rtol=prec, atol=prec)
            elif isinstance(val1, np.ndarray):
                return np.allclose(val1, val2, rtol=prec, atol=prec)
            else:
                return val1 == val2

        for field, v in iter(self):
            if not _isApproxRecurse(v, getattr(other, field), prec):
                return False
        return True
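
    # Sketch of exact vs. approximate comparison on a hypothetical array-valued subclass:
    #
    #     class StateDS(DataStruct):
    #         pos: np.ndarray = Field(default_factory=lambda: np.zeros(3))
    #
    #     a = StateDS(pos=np.array([1.0, 2.0, 3.0]))
    #     b = StateDS(pos=np.array([1.0, 2.0, 3.0 + 1e-12]))
    #     a == b                    # False: __eq__ falls back to np.array_equal, which is exact
    #     a.isApprox(b, prec=1e-9)  # True: fields recurse into np.allclose with rtol=atol=prec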

    def __eq__(self, other: object) -> bool:
        """Check if this DataStruct is equal to another DataStruct of the same type.

        First a normal == is tried on the fields. If that doesn't work, then we recurse into
        the fields looking for numpy arrays.

        Recursively here means that if the field is an iterator, another DataStruct, an iterator
        of DataStructs, etc., we will go into the nested structure calling the appropriate
        operator on the items in the iterator, fields in the DataStruct, etc. This is done
        mainly for numpy arrays, where we want to call np.array_equal rather than use ==.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        bool
            True if the two DataStructs are equal. False otherwise.
        """
        try:
            # First, just try the normal operator. This will work in cases
            # where no fields have a numpy array.
            return super().__eq__(other)
        except:

            def _eqRecurse(val1, val2):
                """Loop through iterators, DataStructs, etc. to try and find numpy arrays.

                Call np.array_equal rather than using __eq__.
                """
                if isinstance(val1, np.ndarray):
                    return np.array_equal(val1, val2)
                elif isinstance(val1, (tuple, list)):
                    return all(_eqRecurse(x, y) for x, y in zip(val1, val2))
                elif isinstance(val1, dict):
                    return all(_eqRecurse(v, val2[k]) for k, v in val1.items())
                else:
                    return val1 == val2

            # Now, loop through the fields one by one.
            for field, v in iter(self):
                v2 = getattr(other, field)
                try:
                    # First try the normal equality operator.
                    if not v == v2:
                        return False
                except:
                    # If that fails, then recurse through looking for numpy arrays.
                    if not _eqRecurse(v, v2):
                        return False

            # Check private fields too.
            if not _eqRecurse(
                getattr(self, "__pydantic_private__", None),
                getattr(other, "__pydantic_private__", None),
            ):
                return False
            return True

# Implement a simple singleton that will serve as a sentinel value.
class SentinalValueClass(object):
    """Dummy class used as a sentinel value."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Create a new instance of the SentinalValueClass."""
        if not cls._instance:
            cls._instance = super(SentinalValueClass, cls).__new__(cls, *args, **kwargs)
        return cls._instance

SentinalValue = SentinalValueClass()


# Dict linker class for linking the cli and input dictionaries.
class DSLinker:
    """Class used to populate a CLI application with `DataStruct` fields and create a
    `DataStruct` instance from CLI options."""

    def __init__(
        self, data_struct: type[DataStruct], exclude: list[str], extra_info: dict[str, Any]
    ):
        """Create a DSLinker instance.

        DSLinkers are used to link a DataStruct with a Kclick CLI.

        Parameters
        ----------
        data_struct : type[DataStruct]
            The DataStruct class to link.
        exclude : list[str]
            List of fields to exclude from the cli.
        extra_info : dict[str, Any]
            Extra information that is used to generate the cli. Examples include:
                * name - Name of the field.
                * help - Description for the field.
                * type - Type for the field.
                * default - Default for the field.
        """
        self.data_struct = data_struct
        self.exclude = exclude
        self.extra_info = extra_info
        # This links the CLI name to the DataStruct name.
        self.name_link: dict[str, str] = {}

    def toCli(self):
        """Use to convert a DataStruct to a Kclick cli."""
        # Also, get the docstrings. We parse this information into a dictionary where the keys are
        # the parameter names and the values are the description.
        if self.data_struct.__doc__ is not None:
            doc_string_params = {
                name: "\n".join(desc)
                for name, _, desc in NumpyDocString(self.data_struct.__doc__)["Parameters"]
            }
        else:
            doc_string_params = {}

        def decorator(function):
            for name, field_info in self.data_struct.model_fields.items():
                # Skip if this is part of the exclude list.
                if name in self.exclude:
                    continue

                # Get parameter information and add a click option to the function.
                p_extra_info = self.extra_info.get(name, {})

                p_type = p_extra_info.get("type", None)
                if p_type is None:
                    p_type = field_info.annotation

                p_default = p_extra_info.get("default", None)
                if p_default is None:
                    p_default = field_info.default
                if p_default == PydanticUndefined:
                    p_default = None

                # For the description, if the parameter has a description use that.
                # If it doesn't, use the docstring information.
                # If it has neither, then set its description to an empty string.
                p_help = p_extra_info.get("help", None)
                if p_help is None:
                    p_help = field_info.description
                if p_help is None:
                    p_help = doc_string_params.get(name, "")

                p_name = p_extra_info.get("name", None)
                if p_name is None:
                    p_name = "--" + name.replace("_", "-")
                self.name_link[p_name] = name

                function = click.option(
                    p_name,
                    type=p_type,
                    default=p_default,
                    help=p_help,
                    required=field_info.is_required(),
                )(function)
            return function

        return decorator
T = TypeVar("T", bound=Base) # Pydantic models have issues with nesting when __init__ is overridden or dump_model is # overridden. The easiest way around this, in this case, was to just make _id and # _objects_from_id private attributes. These are normally not serialized/deserialized, # so model_serializers and model_validators have been added.
class IdMixin(BaseModel, Generic[T]):
    """Mixin to add ID tracking to a DataStruct.

    For bookkeeping, it is common to want to track the id of the object used to create the
    DataStruct. This mixin makes it easy to do so. It adds the private _id variable with
    default value None, adds a karanaId property for it, and overrides the appropriate methods
    so that _id is serialized/deserialized if set.

    It is the job of the class using the mixin to add objects to _objects_from_id whenever
    appropriate.
    """

    _id: int | None = PrivateAttr(default=None)
    _objects_from_id: list[CppWeakRef[T]] = PrivateAttr(default=[])

    @property
    def karanaId(self) -> int | None:
        """Retrieve the private _id variable."""
        return self._id

    @model_serializer(mode="wrap")
    def _serialize_id_field(self, serializer: Callable[[Any], dict[str, Any]]) -> dict[str, Any]:
        """Perform normal serialization, but add the _id field.

        Parameters
        ----------
        serializer : Callable[[Any], dict[str, Any]]
            Normal serializer.

        Returns
        -------
        dict[str, Any]
            A serialized version of the class, with _id added.
        """
        data = serializer(self)
        if self._id is not None:
            data["_id"] = self._id
        return data

    @model_validator(mode="wrap")
    @classmethod
    def _add_id_to_instance(cls, data: Any, handler: ModelWrapValidatorHandler[Self]) -> Self:
        """Perform normal validation, but add _id.

        Parameters
        ----------
        data : Any
            The data for normal validation.
        handler : ModelWrapValidatorHandler[Self]
            The normal validation function.

        Returns
        -------
        Self
            An instance of Self. Validated as normal, but with _id added.
        """
        inst = handler(data)
        if isinstance(data, dict) and "_id" in data:
            inst._id = data["_id"]
        return inst

    def __getstate__(self):
        """Override __getstate__ to include _id if set."""
        state = super().__getstate__()
        if self._id is not None:
            state["__dict__"]["_id"] = self._id
        return state

    @staticmethod
    def findObjectsCreatedById(val: Any, id: int) -> list[T] | None:
        """Find any objects in the val data structure that were created for the ID given by id.

        This assumes unique IDs (and a unique DataStruct with that ID). Therefore, the search
        stops at the first ID.

        Parameters
        ----------
        val : Any
            The data structure to recurse through. This can be a composite type consisting of
            DataStructs, lists, tuples, sets, and dictionaries.
        id : int
            The ID to use in the search.

        Returns
        -------
        list[T] | None
            None if the ID was not found. Otherwise, a list of objects created with the ID.
        """

        def _recurse(val: Any, id=id) -> tuple[list[T], bool]:
            found_id = False
            if isinstance(val, dict):
                for k, v in val.items():
                    dark, found_id = _recurse(k)
                    if found_id:
                        return dark, found_id
                    dark, found_id = _recurse(v)
                    if found_id:
                        return dark, found_id
            elif isinstance(val, (list, tuple, set)):
                for v in val:
                    dark, found_id = _recurse(v)
                    if found_id:
                        return dark, found_id
            elif isinstance(val, IdMixin):
                # We have found the ID. Loop through all the weakrefs. Throw a
                # warning if any are not still around. Return those that are.
                if val._id == id:
                    dark: list[T] = []
                    flag = False
                    for obj in val._objects_from_id:
                        ob = cast(T | None, obj())
                        if ob is not None:
                            dark.append(ob)
                        else:
                            flag = True
                    if flag:
                        warn(
                            "IdMixin found id, but weakrefs returned None. "
                            "Some weakrefs have been deleted."
                        )
                    return dark, True
                else:
                    for _, v in iter(val):
                        dark, found_id = _recurse(v)
                        if found_id:
                            return dark, found_id
            elif isinstance(val, DataStruct):
                for _, v in iter(val):
                    dark, found_id = _recurse(v)
                    if found_id:
                        return dark, found_id
            return [], False

        objs, found_id = _recurse(val)
        if found_id:
            return objs
        else:
            return None
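
# Sketch of how IdMixin might be used (``SensorDS`` and the id value are hypothetical; the
# owning code is responsible for populating _objects_from_id with weak references):
#
#     class SensorDS(IdMixin[Base], DataStruct):
#         bias: float = 0.0
#
#     ds = SensorDS()
#     ds._id = 42
#     IdMixin.findObjectsCreatedById({"sensors": [ds]}, 42)   # [] until weakrefs are added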

def NestedBaseMixin(field_name: str):
    """Create a NestedBaseMixin class.

    This method returns a class that is a mixin. The mixin is used for classes that will be
    nested in other Pydantic models, and that serve as a base class for other Pydantic classes.
    An example of this is KModelDS, which is used in StatePropagatorDS. Other types, e.g.,
    PointGravityModelDS, will be derived from this. When we serialize/deserialize, we want to
    save/load the derived type, not the base type.

    This should be used in conjunction with SerializeAsAny. A field will be added to keep track
    of the type. Its name is given by the field_name parameter.

    Parameters
    ----------
    field_name : str
        The name of the field that is automatically added to the base class as part of the
        mixin and to all derived classes.

    Returns
    -------
    _NestedBaseMixin
        The mixin class.
    """

    class _NestedBaseMixin(BaseModel):
        @classmethod
        def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
            """Add extra field to the subclass.

            For the base class and all derived classes, we want to add a field to keep track
            of the specific type, so we can serialize and deserialize correctly.
            """
            super().__pydantic_init_subclass__(**kwargs)
            name = cls.__module__ + "." + cls.__qualname__
            cls.model_fields[field_name] = FieldInfo(default=name, annotation=Literal[name])
            cls.model_rebuild(force=True)

        @model_validator(mode="wrap")
        @classmethod
        def _useCorrectType(cls, data: Any, handler: ModelWrapValidatorHandler[Self]) -> Self:
            """Use the correct type when deserializing.

            This uses the added field to get the most derived type and deserialize
            appropriately. This is provided that SerializeAsAny was used when serializing.

            Parameters
            ----------
            data : Any
                The data to deserialize from.
            handler : ModelWrapValidatorHandler[Self]
                The handler used for deserialization of the base class.

            Returns
            -------
            Self
                An instance of the most derived type.
            """
            if field_name in data:
                if (
                    data[field_name]
                    == cls.__pydantic_fields__.get(field_name, FieldInfo()).default
                ):  # pyright: ignore - false positive
                    return handler(data)
                else:
                    dark = data[field_name].split(".")
                    class_str = dark[-1]
                    module = ".".join(dark[:-1])
                    mod = import_module(module)
                    klass = getattr(mod, class_str)
                    return klass(**data)
            else:
                return handler(data)

    return _NestedBaseMixin
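
# Sketch of NestedBaseMixin with a hypothetical base/derived pair (field and class names are
# illustrative). SerializeAsAny on the field ensures the derived payload, including the
# type-tracking field, is what gets serialized:
#
#     from pydantic import SerializeAsAny
#
#     class KModelDS(NestedBaseMixin("model_type"), DataStruct):
#         pass
#
#     class PointGravityModelDS(KModelDS):
#         mu: float = 398600.4418
#
#     class StatePropagatorDS(DataStruct):
#         model: SerializeAsAny[KModelDS] = Field(default_factory=PointGravityModelDS)
#
#     sp = StatePropagatorDS(model=PointGravityModelDS())
#     sp2 = StatePropagatorDS.model_validate(sp.model_dump())   # sp2.model keeps the derived type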