# Copyright (c) 2024-2025 Karana Dynamics Pty Ltd. All rights reserved.
#
# NOTICE TO USER:
#
# This source code and/or documentation (the "Licensed Materials") is
# the confidential and proprietary information of Karana Dynamics Inc.
# Use of these Licensed Materials is governed by the terms and conditions
# of a separate software license agreement between Karana Dynamics and the
# Licensee ("License Agreement"). Unless expressly permitted under that
# agreement, any reproduction, modification, distribution, or disclosure
# of the Licensed Materials, in whole or in part, to any third party
# without the prior written consent of Karana Dynamics is strictly prohibited.
#
# THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND.
# KARANA DYNAMICS DISCLAIMS ALL WARRANTIES, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-INFRINGEMENT, AND
# FITNESS FOR A PARTICULAR PURPOSE.
#
# IN NO EVENT SHALL KARANA DYNAMICS BE LIABLE FOR ANY DAMAGES WHATSOEVER,
# INCLUDING BUT NOT LIMITED TO LOSS OF PROFITS, DATA, OR USE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGES, WHETHER IN CONTRACT, TORT,
# OR OTHERWISE ARISING OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS.
#
# U.S. Government End Users: The Licensed Materials are a "commercial item"
# as defined at 48 C.F.R. 2.101, and are provided to the U.S. Government
# only as a commercial end item under the terms of this license.
#
# Any use of the Licensed Materials in individual or commercial software must
# include, in the user documentation and internal source code comments,
# this Notice, Disclaimer, and U.S. Government Use Provision.
"""Classes and functions related to DataStruct.
DataStruct is a wrapper around `pydantic.BaseModel`. It contains extra functionality to:
* Easily save/load `DataStruct`s from various file types.
* Compare `DataStruct`s with one another either identically, using `__eq__`, or
approximately, using `isApprox`.
* Populate a Kclick CLI application's options from a `DataStruct`.
* Create an instance of a `DataStruct` from a Kclick CLI application's options.
* Generate an asciidoc with the data of a `DataStruct`.
The saving/loading and comparison can be done recursively, meaning they accept `DataStruct`s
whose fields are other `DataStruct`s. In addition, they support nested Python types, such as
a dict of lists of numpy arrays.
"""
import h5py
import numpy as np
from numpy.typing import NDArray
import quantities as pq
from importlib import import_module
from inspect import getmodule
from pydantic import (
BaseModel,
ConfigDict,
model_serializer,
PrivateAttr,
model_validator,
ModelWrapValidatorHandler,
Field,
)
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from numpydoc.docscrape import NumpyDocString
from typing import (
Any,
Optional,
ClassVar,
Self,
TypeVar,
cast,
overload,
IO,
Literal,
Callable,
Generic,
)
from pathlib import Path
import click
from copy import deepcopy
from ruamel.yaml import YAML, RoundTripRepresenter
from json import dumps, JSONEncoder, load
from Karana.Math import MATH_EPSILON
from Karana.Core import warn, Base, CppWeakRef
def _h5RecurseSave(pg: h5py.Group, name: Optional[str], data: Any):
"""Recursively save some data to HDF5.
Parameters
----------
pg : h5py.Group
The parent group to attach this data to.
name : Optional[str]
The name to call the data. This will be the name of
the child group or dataset, depending on the data type.
If it is None, then we want to write the data into the
current group we have.
data : Any
The data to save.
"""
if isinstance(data, dict):
if name is not None:
g = pg.create_group(name)
else:
g = pg
g.attrs["PY_TYPE"] = "dict"
for k, v in data.items():
_h5RecurseSave(g, k, v)
elif isinstance(data, list):
if name is not None:
g = pg.create_group(name)
else:
g = pg
g.attrs["PY_TYPE"] = "list"
g.attrs["NUM_ELEMENTS"] = len(data)
for k, d in enumerate(data):
name = f"element_{k}"
_h5RecurseSave(g, name, d)
elif isinstance(data, tuple):
if name is not None:
g = pg.create_group(name)
else:
g = pg
g.attrs["PY_TYPE"] = "tuple"
g.attrs["NUM_ELEMENTS"] = len(data)
for k, d in enumerate(data):
name = f"element_{k}"
_h5RecurseSave(g, name, d)
elif isinstance(data, set):
if name is not None:
g = pg.create_group(name)
else:
g = pg
g.attrs["PY_TYPE"] = "set"
g.attrs["NUM_ELEMENTS"] = len(data)
for k, d in enumerate(data):
name = f"element_{k}"
_h5RecurseSave(g, name, d)
elif isinstance(data, pq.Quantity):
dataset = pg.create_dataset(name, data=data)
dataset.attrs["UNIT"] = str(data.units).split()[1]
dataset.attrs["PY_TYPE"] = "Quantity"
elif isinstance(data, np.ndarray):
dataset = pg.create_dataset(name, data=data)
dataset.attrs["PY_TYPE"] = "np.array"
elif hasattr(data, "__getstate__") and hasattr(data, "__setstate__"):
# If this is a pickleable object, then we use the information there.
if name is not None:
g = pg.create_group(name)
else:
g = pg
mod = getmodule(data.__class__)
if mod is None:
raise ValueError("Issues getting fully qualified string for object.")
name = cast(str, data.__class__.__name__) # None is not possible, as this is a class
py_type = mod.__name__ + "." + name
if hasattr(data, "name"):
if callable(data.name):
name = cast(str, data.name())
else:
name = data.name
else:
name = py_type
g.attrs["PY_TYPE"] = py_type
g.attrs["DATA_NAME"] = name
_h5RecurseSave(g, name, data.__getstate__())
elif isinstance(data, str):
# We give this special attention, as strings will reload as bytes, so we need to
# convert it.
dataset = pg.create_dataset(name, data=data)
dataset.attrs["PY_TYPE"] = "str"
elif data is None:
dataset = pg.create_dataset(name, data="PY_NONE")
dataset.attrs["PY_TYPE"] = "None"
else:
# This is just a normal python object. Save it to a dataset.
dataset = pg.create_dataset(name, data=data)
def _h5RecurseLoad(g: h5py.Group | h5py.Dataset):
"""Recursively load some data to HDF5.
Parameters
----------
g : h5py.Group | h5py.Dataset
The group or dataset to load.
"""
py_type = g.attrs.get("PY_TYPE", None)
if isinstance(g, h5py.Group):
if py_type == "dict":
dark = {}
for k, v in g.items():
dark[k] = _h5RecurseLoad(v)
return dark
elif py_type == "list":
return [_h5RecurseLoad(x) for x in g.values()]
elif py_type == "tuple":
return tuple(_h5RecurseLoad(x) for x in g.values())
elif py_type == "set":
return {_h5RecurseLoad(x) for x in g.values()}
else:
if py_type is None:
raise ValueError("Cannot get py_type for group")
dark = py_type.split(".")
class_str = dark[-1]
module = ".".join(dark[:-1])
mod = import_module(module)
klass = getattr(mod, class_str)
dark = klass.__new__(klass)
name = g.attrs["DATA_NAME"]
dark.__setstate__(_h5RecurseLoad(cast(h5py.Group | h5py.Dataset, g[name])))
return dark
else:
# This is a dataset
if py_type == "str":
return g[()].decode("utf-8")
elif py_type == "None":
return None
elif py_type == "Quantity":
return pq.Quantity(g[()], g.attrs["UNIT"])
else:
return g[()]
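# Illustrative sketch of the two helpers above: round-tripping a nested container
# through an HDF5 file (the file name and payload are arbitrary examples).
#
#     with h5py.File("scratch.h5", "w") as f:
#         _h5RecurseSave(f, "payload", {"gains": np.arange(3), "label": "run1"})
#     with h5py.File("scratch.h5", "r") as f:
#         payload = _h5RecurseLoad(f["payload"])
#     # payload -> {"gains": array([0, 1, 2]), "label": "run1"}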
class _JsonRoundTrip:
"""Helper to serialize and deserialize custom objects.
Register classes or types with this and then use the CustomEncoder and objectHook
when using json.dumps and json.loads respectively to serialize and deserialize
custom objects.
"""
class CustomEncoder(JSONEncoder):
"""Serialize custom types for json.
Register types with the extra_types and extra_types_tags variables.
extra_types pairs a class type with a method that represents that class
as a json serializable dictionary. Note, the dictionary can have other custom types
and those will be handled recursively.
extra_types_tags maps each type to a unique tag.
"""
_rt: "_JsonRoundTrip"
extra_types: dict[Any, Callable[[Any], dict[str, Any]]] = {}
extra_types_tags: dict[Any, str] = {}
def default(self, o: Any) -> dict[str, Any]:
"""Serialize json objects.
Overrides the default method and handles custom objects first. It will try to register
objects it does not recognize if they have the to_json method defined. If the object is
not a custom object, then just use the parent method.
Parameters
----------
o : Any
The object to serialize.
Returns
-------
dict[str, Any]
The serialized object.
"""
# Check registered custom types first, keyed on the exact type of the object.
v = self.extra_types.get(type(o), None)
if v is not None:
return v(o) | {"__type__": self.extra_types_tags[type(o)]}
# This is a type we can register, so do that:
if hasattr(o, "to_json"):
cls = type(o)
if not "json_tag" in cls.__dict__:
cls.json_tag = "!" + cls.__module__ + "." + cls.__qualname__
self._rt.registerClass(cls)
return cls.to_json(o) | {"__type__": self.extra_types_tags[type(o)]}
# If this is not a custom object we know about, nor one we can register, call the parent method.
return super().default(o)
def __init__(self):
"""Create _JsonRoundTrip instance."""
self.extra_types: dict[str, Callable[[dict[str, Any]], Any]] = {}
self.CustomEncoder._rt = self
def registerType[
T
](
self,
klass: type[T],
tag: str,
serialize: Callable[[T], dict[str, Any]],
deserialize: Callable[[dict[str, Any]], T],
) -> None:
"""Register a type T with the custom serializer and deserializer.
Parameters
----------
klass : type[T]
The type to register.
tag : str
The tag associated with the given type.
serialize : Callable[[T], dict[str, Any]]
The serialization function.
deserialize : Callable[[dict[str,Any]], T]
The deserialization function.
"""
self.CustomEncoder.extra_types[klass] = serialize
self.CustomEncoder.extra_types_tags[klass] = tag
self.extra_types[tag] = deserialize
def registerClass(self, klass: Any):
"""Register a class with the custom serializer.
The class must define a json_tag, which is a string,
a to_json method, which is the serializer, and a from_json method,
which is the deserializer.
Parameters
----------
klass : Any
The class to register.
"""
for atr in ["json_tag", "to_json", "from_json"]:
if not hasattr(klass, atr):
raise ValueError(f"Class {klass} is missing {atr}.")
self.registerType(klass, klass.json_tag, klass.to_json, klass.from_json)
def objectHook(self, d: dict[str, Any]) -> Any:
"""Deserialize function for json.loads.
This uses the registered types during deserialization. If it encounters a type it does
not recognize, it will try to dynamically register it.
Parameters
----------
d : dict[str,Any]
The serialized object.
Returns
-------
Any
The deserialized object.
"""
if "__type__" in d:
t = d["__type__"]
# Try looking into registered classes first
for k, v in reversed(list(self.extra_types.items())):
if t == k:
return v(d)
# Try loading this type dynamically assuming __type__ is the class
# If that fails, just try the default
try:
# Try importing the __type__ as a class and creating it
dark = t[1:].split(".")
class_str = dark[-1]
module = ".".join(dark[:-1])
mod = import_module(module)
cls = getattr(mod, class_str)
# We check cls.__dict__ here rather than hasattr, as we only want to check this version of the class.
# We don't want to look into base classes. Otherwise, we will get potentially the wrong tag.
if not "json_tag" in cls.__dict__:
cls.json_tag = t
self.registerClass(cls)
return cls.from_json(d)
except Exception:
pass
return d
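# Illustrative sketch: registering a type with _JsonRoundTrip and using it with the
# standard json module (complex is just a stand-in for a custom type).
#
#     import json
#     rt = _JsonRoundTrip()
#     rt.registerType(
#         complex,
#         "!complex",
#         lambda z: {"re": z.real, "im": z.imag},
#         lambda d: complex(d["re"], d["im"]),
#     )
#     s = json.dumps({"root": 1 + 2j}, cls=rt.CustomEncoder)
#     json.loads(s, object_hook=rt.objectHook)   # -> {"root": (1+2j)}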
class DataStruct(BaseModel):
"""Wrapper around `pydantic.BaseModel` that adds functionality useful for modeling and simulation.
This class adds functionality to the `pydantic.BaseModel`, including:
* Easily save/load `DataStruct`s from various file types.
* Compare `DataStruct`s with one another either identically, using `__eq__`, or
approximately, using `isApprox`.
* Generate an asciidoc with the data of a `DataStruct`.
The saving/loading and comparison can be done recursively, meaning they accept `DataStruct`s
whose fields are other `DataStruct`s. In addition, they support nested Python types, such as
a dict of lists of numpy arrays.
Parameters
----------
version : tuple[int, int]
Holds the version of this DataStruct. Users should override the current version using the
_version_default class variable. DataStructs that use this should also add a field validator or
model validator to handle version mismatches.
"""
# Define a ClassVar with the default for version.
# This allows users to easily override it.
_version_default: ClassVar[tuple[int, int]] = (0, 0)
# Make this strict=False so pydantic will coerce a list to a tuple.
version: tuple[int, int] = Field(default=_version_default, strict=False)
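# Illustrative sketch of version handling (`SensorDS` and its fields are hypothetical):
# a subclass overrides _version_default, re-declares the version field so the new
# default is picked up, and migrates old payloads in a before-validator.
#
#     class SensorDS(DataStruct):
#         _version_default: ClassVar[tuple[int, int]] = (2, 0)
#         version: tuple[int, int] = Field(default=_version_default, strict=False)
#         rate_hz: float = 100.0
#
#         @model_validator(mode="before")
#         @classmethod
#         def _migrate(cls, data: Any) -> Any:
#             if isinstance(data, dict) and tuple(data.get("version", (2, 0))) < (2, 0):
#                 data["rate_hz"] = data.pop("rate", 100.0)   # field renamed in v2.0
#                 data["version"] = (2, 0)
#             return data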
# Enable strict mode to avoid type coercion
model_config = ConfigDict(strict=True, arbitrary_types_allowed=True, validate_assignment=True)
# Holds the DSLinker if it exists for this class
_ds_linker: ClassVar[Optional["DSLinker"]] = None
def _migrateExperimental(self, deep=True) -> Self:
"""Create a copy, filling in missing fields with default values.
This may be useful in cases where the DataStruct was
instantiated without going through the constructor such as when
unpickling.
Parameters
----------
deep : bool, default True
If True, performs a deep copy.
Returns
-------
Self
The copy
"""
if deep:
return self.model_validate(self.model_dump(), strict=False)
return self.model_validate(self.__dict__, strict=False)
@classmethod
def generateAsciidoc(cls) -> str:
"""Generate asciidoc for the DataStruct."""
# Also, get the docstrings. We parse this information into a dictionary where the keys are
# the parameter names and the values are the description.
if cls.__doc__ is not None:
doc_string_params = {
name: "\n".join(desc)
for name, _, desc in NumpyDocString(cls.__doc__).get("Parameters", ())
}
else:
doc_string_params = {}
# Write a title and table header.
title = f"=== {cls.__name__} ===\n\n"
fields_table = "|===\n|Field |Type |Required |Description| Default\n"
# Loop through each parameter and add its info to the table
for name, field_info in cls.model_fields.items():
# Get the field annotation
field_type = field_info.annotation
# For common base types, convert to the name.
if field_type is int:
field_type = "int"
elif field_type is str:
field_type = "str"
elif field_type is float:
field_type = "float"
required = "Yes" if field_info.is_required() else "No"
# For the description, if the parameter has a description use that.
# If it doesn't use the docstring information.
# If it doesn't have either, then set its description to an empty string.
description = field_info.description
if description is None or description == "":
description = doc_string_params.get(name, "")
default = field_info.default
if default is PydanticUndefined:
default = ""
fields_table += f"|{name} |{field_type} |{required} |{description} |{default}\n"
fields_table += "|===\n"
return title + fields_table
@classmethod
def cli(cls, exclude: list[str] = [], extra_info: dict[str, Any] = {}):
"""Add this DataStruct's fields to a cli.
Parameters
----------
exclude : list[str]
List of fields to exclude from the cli.
extra_info : dict[str, Any]
Extra information that is used to generate the cli. Examples include:
* name - Name for the field.
* help - Description for the field.
* type - Type for the field.
* default - Default for the field.
"""
if cls._ds_linker is None:
cls._ds_linker = DSLinker(cls, exclude, extra_info)
else:
if not (cls._ds_linker.exclude == exclude and cls._ds_linker.extra_info == extra_info):
raise ValueError(
"DSLinker already exists with a different exclude and extra info. Need to create DSLinkers manually."
)
return cls._ds_linker.toCli()
@classmethod
def fromLinker(
cls,
vals_orig: dict[str, Any],
excluded_vals: dict[str, Any] = {},
ds_linker: Optional["DSLinker"] = None,
) -> Self:
"""Create an instance of the DataStruct from a CLI dictionary.
Parameters
----------
vals_orig : dict[str, Any]
CLI dictionary to use to build the DataStruct.
excluded_vals : dict[str, Any]
Include any values that were excluded from the DSLinker.
ds_linker : Optional[DSLinker]
DSLinker to use to build the DataStruct. This is optional, and is
only necessary if there are multiple DSLinkers for a DataStruct.
"""
if ds_linker is None:
ds_linker = cls._ds_linker
if ds_linker is None:
raise ValueError(f"DSLinker does not exist for {cls}.")
if not ds_linker.name_link:
raise ValueError(f"DSLinker has not been run for {cls}.")
vals = deepcopy(vals_orig) | excluded_vals
keys = list(vals.keys())
for k in keys:
if k in ds_linker.exclude:
continue
k_new = ds_linker.name_link["--" + k.replace("_", "-")]
vals[k_new] = vals.pop(k)
return cls(**vals)
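# Illustrative sketch (`RunConfigDS` is a hypothetical subclass): expose the fields on
# a click command via cli() and rebuild the DataStruct from the parsed options.
#
#     @click.command()
#     @RunConfigDS.cli(exclude=["version"])
#     def run(**kwargs):
#         cfg = RunConfigDS.fromLinker(kwargs)
#         ...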
@classmethod
def _ruamelYaml(cls) -> YAML:
"""Get the YAML class for this DataStruct.
This contains a dynamic loader and dumper that will try to automatically register
classes that have a to_yaml or from_yaml method defined.
Returns
-------
YAML
The YAML class for this DataStruct.
"""
from ruamel.yaml import CommentedMap
yaml = YAML()
# Set flow style to None so that lists get represented as []
yaml.default_flow_style = None
# Add LazyRepresenter. This will dynamically register classes for dumping if they have the
# to_yaml method defined. A yaml_tag will be given if the class does not already define it.
class LazyRepresenter(RoundTripRepresenter):
def represent_data(self, data):
cls = type(data)
if cls not in self.yaml_representers:
if hasattr(cls, "to_yaml"):
# We check cls.__dict__ here rather than hasattr, as we only want to check this version of the class.
# We don't want to look into base classes. Otherwise, we will get potentially the wrong tag.
if not "yaml_tag" in cls.__dict__:
cls.yaml_tag = "!" + cls.__module__ + "." + cls.__qualname__
yaml.register_class(cls)
return super().represent_data(data)
yaml.Representer = LazyRepresenter
# Add LazyConstructor. This will try to dynamically register classes it does not recognize.
def dynamicConstructor(loader, node):
# This handles any tag we haven't seen before
tag_suffix = node.tag[1:]
try:
dark = tag_suffix.split(".")
class_str = dark[-1]
module = ".".join(dark[:-1])
mod = import_module(module)
cls = getattr(mod, class_str)
# We check cls.__dict__ here rather than hasattr, as we only want to check this version of the class.
# We don't want to look into base classes. Otherwise, we will get potentially the wrong tag.
if not "yaml_tag" in cls.__dict__:
cls.yaml_tag = node.tag
yaml.register_class(cls)
return cls.from_yaml(loader, node)
except Exception as err:
raise ValueError(f"Cannot dynamically load yaml constructor for {tag_suffix}") from err
yaml.Constructor.add_constructor(None, dynamicConstructor)
# Representer for numpy arrays
NumpyArrayTag = "!numpy_array"
# Custom representer for numpy arrays
def numpyArrayRepresenter(dumper, data):
# Convert the NumPy array to a list
return dumper.represent_mapping(
NumpyArrayTag,
{"dtype": str(data.dtype), "shape": data.shape, "data": data.tolist()},
)
# Custom constructor for numpy arrays
def numpyArrayConstructor(loader, node):
# Get the mapping from the YAML node
values = CommentedMap()
loader.construct_mapping(node, values, deep=True)
# Recreate the NumPy array
return np.array(values["data"], dtype=values["dtype"]).reshape(values["shape"])
# Register the representer and constructor with ruamel.yaml
yaml.register_class(
type("NumpyArrayWrapper", (), {})
) # Avoid default serialization for NumPy array subclasses
yaml.representer.add_representer(np.ndarray, numpyArrayRepresenter)
yaml.constructor.add_constructor(NumpyArrayTag, numpyArrayConstructor)
# Representer for quantities
QuantityTag = "!quantity"
# Custom representer for pq.Quantity
def quantityRepresenter(dumper, data):
return dumper.represent_mapping(
QuantityTag,
{
"units": str(data.units).split()[1],
"dtype": str(data.dtype),
"shape": data.shape,
"data": data.magnitude.tolist(),
},
)
# Custom constructor for quantities
def quantityConstructor(loader, node):
# Get the mapping from the YAML node
values = CommentedMap()
loader.construct_mapping(node, values, deep=True)
# Recreate the Quantity
return pq.Quantity(
np.array(values["data"], dtype=values["dtype"]).reshape(values["shape"]),
values["units"],
)
# Register the representer and constructor with ruamel.yaml
yaml.register_class(
type("QuantityWrapper", (), {})
) # Avoid default serialization for Quantity subclasses
yaml.representer.add_representer(pq.Quantity, quantityRepresenter)
yaml.constructor.add_constructor(QuantityTag, quantityConstructor)
return yaml
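# With the representers above, a numpy array field round-trips through YAML as a
# tagged mapping roughly of the form below (illustrative; flow style may vary):
#
#     !numpy_array
#     dtype: float64
#     shape: [2, 2]
#     data: [[1.0, 0.0], [0.0, 1.0]]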
@classmethod
def _jsonRoundTrip(cls) -> _JsonRoundTrip:
"""Get the _JsonRoundTrip class for this DataStruct.
This contains a dynamic loader and dumper that will try to automatically register
classes that have a to_json or from_json method defined.
Returns
-------
_JsonRoundTrip
The _JsonRoundTrip class for this DataStruct.
"""
jrt = _JsonRoundTrip()
# Representer for numpy arrays
numpy_array_tag = "!numpy_array"
# Custom representer for numpy arrays
def numpyArrayDump(o: NDArray):
return {"dtype": str(o.dtype), "shape": o.shape, "data": o.tolist()}
# Custom constructor for numpy arrays
def numpyArrayLoad(d: dict[str, Any]):
return np.array(d["data"], dtype=d["dtype"]).reshape(d["shape"])
jrt.registerType(np.ndarray, numpy_array_tag, numpyArrayDump, numpyArrayLoad)
# Representer for quantities
quantity_tag = "!quantity"
# Custom representer for pq.Quantity
def quantityDump(o: pq.Quantity) -> dict[str, Any]:
return {
"units": str(o.units).split()[1],
"dtype": str(o.dtype),
"shape": o.shape,
"data": o.magnitude.tolist(),
}
# Custom constructor for pq.Quantity
def quantityLoad(d: dict[str, Any]):
# Recreate the Quantity
return pq.Quantity(
np.array(d["data"], dtype=d["dtype"]).reshape(d["shape"]),
d["units"],
)
jrt.registerType(pq.Quantity, quantity_tag, quantityDump, quantityLoad)
return jrt
@overload
def toFile(
self,
file: Path | str | IO[bytes],
suffix: Optional[
Literal[".json", ".yaml", ".yml", ".h5", ".hdf5", ".pickle", ".pck", ".pcl"]
] = None,
) -> None:
"""Write the DataStruct to a file.
Parameters
----------
file : Path | str | IO[bytes]
Specify the file to write to. For each type, the suffix of the filename is used to determine
the type of file to write to, e.g., ".yaml" for a YAML file or ".h5" for an HDF5 file, if the optional
suffix argument is not specified. See below for details. In the case of IO[bytes], if the IO object
has a name attr, that will be used to lookup the suffix.
suffix : Optional[Literal[".json",".yaml",".yml",".h5",".hdf5",".pickle",".pck",".pcl"]]
This optional keyword argument is used to set the file type. If it is not specified, then the file type
is inferred from the name. If it can't be inferred from the name, then YAML is used.
"""
... # pragma: no cover. Will not be run by code coverage.
@overload
def toFile(self, g: h5py.Group) -> None:
"""Write the DataStruct to a H5 group.
Parameters
----------
g : h5py.Group
The group to write to.
"""
... # pragma: no cover. Will not be run by code coverage.
def toFile(self, *args, **kwargs):
"""Write the DataStruct to a H5 group.
See overloads for deatils.
"""
if isinstance(group := args[0], h5py.Group):
_h5RecurseSave(group, None, self)
else:
# Figure out the suffix
file = args[0]
suffix = kwargs.get("suffix", None)
if suffix is None:
if isinstance(file, Path):
suffix = file.suffix
elif isinstance(file, str):
suffix = "." + file.split(".")[-1].strip()
elif hasattr(file, "name"):
suffix = "." + file.name.split(".")[-1].strip()
else:
warn("Cannot determine suffix from {file}. Using yaml.")
suffix = ".yaml"
def _write(f, suffix):
if suffix == ".json":
rt = self._jsonRoundTrip()
f.write(dumps(self.model_dump(), cls=rt.CustomEncoder))
elif suffix in [".pickle", ".pck", ".pcl"]:
import pickle
pickle.dump(self, f)
elif suffix in [".hdf5", ".h5"]:
_h5RecurseSave(f, None, self)
else:
yaml = self._ruamelYaml()
yaml.dump(self.model_dump(), f)
if isinstance(file, Path) or isinstance(file, str):
# If this is a string or path, we need to open the file to write it.
# Do so in a context to ensure it is closed.
if suffix in [".hdf5", ".h5"]:
with h5py.File(file, "w") as f:
_write(f, suffix)
elif suffix in [".pickle", ".pck", ".pcl"]:
with open(file, "wb") as f:
_write(f, suffix)
else:
with open(file, "w") as f:
_write(f, suffix)
else:
# This is an already open file. Just write to it.
_write(file, suffix)
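# Illustrative sketch (`MyDS` is a hypothetical subclass): writing a DataStruct into a
# group of an already-open HDF5 file, alongside other datasets.
#
#     with h5py.File("results.h5", "w") as f:
#         MyDS().toFile(f.create_group("config"))   # h5py.Group overload
#     with h5py.File("results.h5", "r") as f:
#         cfg = MyDS.fromFile(f["config"])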
@classmethod
@overload
def fromFile(
cls,
file: Path | str | IO[bytes],
suffix: Optional[
Literal[".json", ".yaml", ".yml", ".h5", ".hdf5", ".pickle", ".pck", ".pcl"]
] = None,
) -> Self:
"""Create an instance of this DataStruct from a file.
Parameters
----------
file : Path | str | IO[bytes]
Specify the file to read from. For each type, the suffix of the filename is used to determine
the type of file to read from, e.g., ".yaml" for a YAML file or ".h5" for an HDF5 file, if the optional
suffix argument is not specified. See below for details. In the case of IO[bytes], if the IO object
has a name attr, that will be used to lookup the suffix.
suffix : Optional[Literal[".json",".yaml",".yml",".h5",".hdf5",".pickle",".pck",".pcl"]]
This optional keyword argument is used to set the file type. If it is not specified, then the file type
is inferred from the name. If it can't be inferred from the name, then YAML is used.
Returns
-------
Self
Instance of this DataStruct.
"""
... # pragma: no cover. Will not be run by code coverage.
@classmethod
@overload
def fromFile(cls, g: h5py.Group) -> Self:
"""Create an instance of this DataStruct from a h5py.Group.
Parameters
----------
g : h5py.Group
The group to read from.
Returns
-------
Self
Instance of this DataStruct.
"""
... # pragma: no cover. Will not be run by code coverage.
@classmethod
def fromFile(cls, *args, **kwargs) -> Self:
"""Create an instance of this DataStruct a file.
See overloads for details.
"""
if isinstance(group := args[0], h5py.Group):
return cast(Self, _h5RecurseLoad(group))
else:
# Figure out the suffix
file = args[0]
suffix = kwargs.get("suffix", None)
if suffix is None:
if isinstance(file, Path):
suffix = file.suffix
elif isinstance(file, str):
suffix = "." + file.split(".")[-1].strip()
elif hasattr(file, "name"):
suffix = "." + file.name.split(".")[-1].strip()
else:
warn("Cannot determine suffix from {file}. Using yaml.")
suffix = ".yaml"
def _read(f, suffix):
if suffix == ".json":
rt = cls._jsonRoundTrip()
return load(f, object_hook=rt.objectHook)
elif suffix in [".pickle", ".pck", ".pcl"]:
import pickle
return pickle.load(f)
else:
yaml = cls._ruamelYaml()
return yaml.load(f)
if isinstance(file, Path) or isinstance(file, str):
# If this is a string or path, we need to open the file to read it.
# Do so in a context to ensure it is closed.
if suffix in [".hdf5", ".h5"]:
with h5py.File(file, "r") as f:
ret = cast(Self, _h5RecurseLoad(f))
return ret
elif suffix in [".pickle", ".pck", ".pcl"]:
with open(file, "rb") as f:
return _read(f, suffix)
else:
with open(file, "r") as f:
from_dict = _read(f, suffix)
from_dict = cast(dict[str, Any], from_dict)
return cls(**from_dict)
else:
# Otherwise, this is an IO object that is already open. Just read and close.
# No need to handle H5 files here, because they are also H5Groups, so will be
# handled above.
if suffix in [".pickle", ".pck", ".pcl"]:
return _read(file, suffix)
else:
from_dict = _read(file, suffix)
from_dict = cast(dict[str, Any], from_dict)
return cls(**from_dict)
def isApprox(self, other: Self, prec: float = MATH_EPSILON) -> bool:
"""Check if this DataStruct is approximately equal to another DataStruct of the same type.
This recursively moves through the public fields only. Note that Pydantic's __eq__
checks private and public fields.
Recursively here means that if the field is an iterator, another DataStruct,
an iterator of DataStructs, etc., we will go into the nested structure,
calling isApprox where appropriate on the items in the iterator, fields
in the DataStruct, etc. If the field (or iterated value, etc.) does not
have an isApprox method, then we fall back to using __eq__. If all
values of isApprox (or __eq__) return True, then this returns True;
otherwise, this returns False.
Parameters
----------
other : Self
A DataStruct of the same type.
prec : float
The precision to use. Karana.Math.MATH_EPSILON is the default.
Returns
-------
bool
True if the two DataStructures are approximately equal. False otherwise.
"""
def _isApproxRecurse(val1, val2, prec: float):
# Call isApprox if it exists on the item. Otherwise, call the == method.
# Loop through different iterator types recursively as appropriate
if hasattr(val1, "isApprox"):
return val1.isApprox(val2, prec=prec)
elif isinstance(val1, (tuple, list)):
if len(val1) != len(val2):
return False
return all(_isApproxRecurse(x, y, prec) for x, y in zip(val1, val2))
elif isinstance(val1, dict):
if len(val1) != len(val2):
return False
else:
return all(_isApproxRecurse(v, val2[k], prec) for k, v in val1.items())
elif isinstance(val1, pq.Quantity):
# If these are quantities, then just look at the raw values. Leaving the units
# in would result in errors as this compares with things like
# atol + rtol * abs(y), which would change the units to be units**2 for
# rtol * abs(y)
return np.allclose(val1.magnitude, val2.magnitude, rtol=prec, atol=prec)
elif isinstance(val1, np.ndarray):
return np.allclose(val1, val2, rtol=prec, atol=prec)
else:
return val1 == val2
for field, v in iter(self):
if not _isApproxRecurse(v, getattr(other, field), prec):
return False
return True
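# Illustrative sketch (`TrajDS` is hypothetical): exact vs. approximate comparison of
# numpy-array fields.
#
#     class TrajDS(DataStruct):
#         positions: np.ndarray = Field(default_factory=lambda: np.zeros(3))
#
#     a = TrajDS()
#     b = TrajDS(positions=np.zeros(3) + 1e-14)
#     a == b                      # False: numpy fields compared with np.array_equal
#     a.isApprox(b, prec=1e-12)   # True: numpy fields compared with np.allclose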
def __eq__(self, other: object) -> bool:
"""Check if this DataStruct is equal to another DataStruct of the same type.
First a normal == is tried on the fields. If that doesn't work,
then we recurse into the fields looking for numpy arrays. Recursively
here means that if the field is an iterator, another DataStruct, an
iterator of DataStructs, etc., we will go into the nested structure,
calling the appropriate operator on the items in the iterator, fields
in the DataStruct, etc. This is done mainly for numpy arrays, where we
want to call np.array_equal rather than use ==.
Parameters
----------
other : object
The object to compare with.
Returns
-------
bool
True if the two DataStructures are equal. False otherwise.
"""
try:
# First, just try the normal operator. This will work in cases
# where no fields have a numpy array
return super().__eq__(other)
except Exception:
def _eqRecurse(val1, val2):
"""Loop through iterators, DataStructs, etc. to try and find numpy arrays.
Call np.array_equal rather than using __eq__
"""
if isinstance(val1, np.ndarray):
return np.array_equal(val1, val2)
elif isinstance(val1, (tuple, list)):
if len(val1) != len(val2):
return False
return all(_eqRecurse(x, y) for x, y in zip(val1, val2))
elif isinstance(val1, dict):
if not isinstance(val2, dict) or val1.keys() != val2.keys():
return False
return all(_eqRecurse(v, val2[k]) for k, v in val1.items())
else:
return val1 == val2
# Now, loop through the fields one by one.
for field, v in iter(self):
v2 = getattr(other, field)
try:
# First try the normal equality operator.
if not v == v2:
return False
except Exception:
# If that fails, then recurse through looking through numpy arrays
if not _eqRecurse(v, v2):
return False
# Check private fields too
if not _eqRecurse(
getattr(self, "__pydantic_private__", None),
getattr(other, "__pydantic_private__", None),
):
return False
return True
# Implement a simple singleton that will serve as a sentinel
class SentinalValueClass(object):
"""Dummy class uses as a sentinal value."""
_instance = None
def __new__(cls, *args, **kwargs):
"""Create a new instance of the SentinalValueClass."""
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
SentinalValue = SentinalValueClass()
# DSLinker class for linking the cli and input dictionaries
class DSLinker:
"""Class used to populate a CLI application with `DataStruct` fields and create a `DataStruct` instance from CLI options."""
def __init__(
self, data_struct: type[DataStruct], exclude: list[str], extra_info: dict[str, Any]
):
"""Create an DSLinker instance.
DSLinkers are used to link a DataStruct with a Kclick CLI.
Parameters
----------
data_struct : type[DataStruct]
The DataStruct class to link to the CLI.
exclude : list[str]
List of fields to exclude from the cli.
extra_info : dict[str, Any]
Extra information that is used to generate the cli. Examples include:
* name - Name for the field.
* help - Description for the field.
* type - Type for the field.
* default - Default for the field.
"""
self.data_struct = data_struct
self.exclude = exclude
self.extra_info = extra_info
# This links the CLI name to the DataStruct name.
self.name_link: dict[str, str] = {}
def toCli(self):
"""Use to convert a DataStruct to a Kclick cli."""
# Also, get the docstrings. We parse this information into a dictionary where the keys are
# the parameter names and the values are the description.
if self.data_struct.__doc__ is not None:
doc_string_params = {
name: "\n".join(desc)
for name, _, desc in NumpyDocString(self.data_struct.__doc__)["Parameters"]
}
else:
doc_string_params = {}
def decorator(function):
for name, field_info in self.data_struct.model_fields.items():
# Skip if this is part of the exclude list
if name in self.exclude:
continue
# Get parameter information and add a click option to the function
p_extra_info = self.extra_info.get(name, {})
p_type = p_extra_info.get("type", None)
if p_type is None:
p_type = field_info.annotation
p_default = p_extra_info.get("default", None)
if p_default is None:
p_default = field_info.default
if p_default is PydanticUndefined:
p_default = None
# For the description, if the parameter has a description use that.
# If it doesn't use the docstring information.
# If it doesn't have either, then set its description to an empty string.
p_help = self.extra_info.get("help", None)
if p_help is None:
p_help = field_info.description
if p_help is None:
p_help = doc_string_params.get(name, "")
p_name = self.extra_info.get("name", None)
if p_name is None:
p_name = "--" + name.replace("_", "-")
self.name_link[p_name] = name
function = click.option(
p_name,
type=p_type,
default=p_default,
help=p_help,
required=field_info.is_required(),
)(function)
return function
return decorator
T = TypeVar("T", bound=Base)
# Pydantic models have issues with nesting when __init__ is overridden or model_dump is
# overridden. The easiest way around this, in this case, was to just make _id and
# _objects_from_id private attributes. These are normally not serialized/deserialized,
# so model_serializers and model_validators have been added.
class IdMixin(BaseModel, Generic[T]):
"""Mixin to add ID tracking to a DataStruct.
For bookkeeping, it is common to want to track the id of the
object used to create the DataStruct. This mixin makes it easy to do so.
It adds the private _id variable with default value None, adds a
karana_id property for it, and overrides the appropriate methods so that
_id is serialized/deserialized if set. It is the job of the class using the
mixin to add objects to _objects_from_id whenever appropriate.
"""
_id: int | None = PrivateAttr(default=None)
_objects_from_id: list[CppWeakRef[T]] = PrivateAttr(default=[])
@property
def karana_id(self) -> int | None:
"""Retrieve the private _id variable."""
return self._id
@model_serializer(mode="wrap")
def _serializeIdField(self, serializer: Callable[[Any], dict[str, Any]]) -> dict[str, Any]:
"""Perform normal serialization, but add the _id field.
Parameters
----------
serializer : Callable[[Any], dict[str, Any]]
Normal serializer.
Returns
-------
dict[str, Any]
A serialized version of the class, with _id added.
"""
data = serializer(self)
if self._id is not None:
data["_id"] = self._id
return data
@model_validator(mode="wrap")
@classmethod
def _addIdToInstance(cls, data: Any, handler: ModelWrapValidatorHandler[Self]) -> Self:
"""Perform normal validation, but add _id.
Parameters
----------
data : Any
The data for normal validation.
handler : ModelWrapValidatorHandler[Self]
The normal validation function.
Returns
-------
Self
An instance of Self. Validated as normal, but with _id added.
"""
inst = handler(data)
if isinstance(data, dict) and "_id" in data:
inst._id = data["_id"]
return inst
def __getstate__(self):
"""Override __getstate__ to include _id if set."""
state = super().__getstate__()
if self._id is not None:
state["__dict__"]["_id"] = self._id
return state
@staticmethod
def findObjectsCreatedById(val: Any, id: int) -> list[T] | None:
"""Find any objects in the val data structure that were created for the ID given by id.
This assumes unique IDs (and a unique DataStruct with that ID). Therefore, the search stops
at the first ID.
Parameters
----------
val : Any
The data structure to recurse through. This can be a composite type consisting of DataStructs,
lists, tuples, sets, and dictionaries.
id : int
The ID to use in the search.
Returns
-------
list[T] | None
None if the ID was not found. Otherwise, a list of objects created with the ID.
"""
def _recurse(val: Any, id=id) -> tuple[list[T], bool]:
found_id = False
if isinstance(val, dict):
for k, v in val.items():
dark, found_id = _recurse(k)
if found_id:
return dark, found_id
dark, found_id = _recurse(v)
if found_id:
return dark, found_id
elif isinstance(val, list) or isinstance(val, tuple) or isinstance(val, set):
for v in val:
dark, found_id = _recurse(v)
if found_id:
return dark, found_id
elif isinstance(val, IdMixin):
# We have found the ID. Loop through all the weakrefs. Throw a
# warning if any are not still around. Return those that are.
if val._id == id:
dark: list[T] = []
flag = False
for obj in val._objects_from_id:
ob = cast(T | None, obj())
if ob is not None:
dark.append(ob)
else:
flag = True
if flag:
warn(
"IdMixin found id, but weakrefs returned None. Some weakrefs have been deleted."
)
return dark, True
else:
for _, v in iter(val):
dark, found_id = _recurse(v)
if found_id:
return dark, found_id
elif isinstance(val, DataStruct):
for _, v in iter(val):
dark, found_id = _recurse(v)
if found_id:
return dark, found_id
return [], False
objs, found_id = _recurse(val)
if found_id:
return objs
else:
return None
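# Illustrative sketch (`StateDS` and the originating C++ object are hypothetical): the
# code that builds a StateDS from a C++ object is expected to record the object's id
# and a weak reference; the live objects can then be recovered from a nested container.
#
#     class StateDS(IdMixin, DataStruct):
#         x: float = 0.0
#
#     ds = StateDS()
#     # ... ds._id and ds._objects_from_id populated by the code that created ds ...
#     objs = IdMixin.findObjectsCreatedById({"runs": [ds]}, ds.karana_id)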
def NestedBaseMixin(field_name: str): # pylint: disable=invalid-name
"""Create a NestedBaseMixin class.
This method returns a class that is a mixin. The mixin is used for classes that will be
nested in other Pydantic models, and that serve as a base class for other Pydantic classes.
An example of this is KModelDS, which is used in StatePropagatorDS. Other types, e.g.,
PointGravityModelDS, will be derived from this. When we serialize/deserialize, we want to save/load
the derived type, not the base type. This should be used in conjunction with SerializeAsAny.
A field will be added to keep track of the type. Its name is given by the field_name parameter.
Parameters
----------
field_name : str
The name of the field that is automatically added to the base class as part of the mixin
and all derived classes.
Returns
-------
type[_NestedBaseMixin]
The mixin class to inherit from.
"""
class _NestedBaseMixin(BaseModel):
@classmethod
def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
"""Add extra field to the subclass.
For the base class and all derived classes, we want to add a field to keep track
of the specific type, so we can serialize and deserialize correctly.
"""
super().__pydantic_init_subclass__(**kwargs)
name = cls.__module__ + "." + cls.__qualname__
cls.model_fields[field_name] = FieldInfo(default=name, annotation=Literal[name])
cls.model_rebuild(force=True)
@model_validator(mode="wrap")
@classmethod
def _useCorrectType(cls, data: Any, handler: ModelWrapValidatorHandler[Self]) -> Self:
"""Use the correct type when deserializing.
This uses the added field to get the most derived type and deserialize appropriately.
This is provided that SerializeAsAny was used when serializing.
Parameters
----------
data : Any
The data to deserialize from.
handler : ModelWrapValidatorHandler[Self]
The handler used for deserialization of the base class.
Returns
-------
Self
An instance of the most derived type.
"""
if field_name in data:
if (
data[field_name] == cls.__pydantic_fields__.get(field_name, FieldInfo()).default
): # pyright: ignore - false positive
return handler(data)
else:
dark = data[field_name].split(".")
class_str = dark[-1]
module = ".".join(dark[:-1])
mod = import_module(module)
klass = getattr(mod, class_str)
return klass(**data)
else:
return handler(data)
return _NestedBaseMixin
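# Illustrative sketch (`ModelDS`/`PointModelDS` are hypothetical; pair the base type
# with pydantic's SerializeAsAny on fields that hold it): the mixin records the
# concrete type under "model_type" so validation restores the derived class.
#
#     class ModelDS(NestedBaseMixin("model_type"), DataStruct):
#         pass
#
#     class PointModelDS(ModelDS):
#         mu_km3_s2: float = 398600.4418
#
#     dumped = PointModelDS().model_dump()       # includes the "model_type" entry
#     restored = ModelDS.model_validate(dumped)  # dispatches to PointModelDS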