"""
Mixin classes for mesa-frames abstract components.
This module defines mixin classes that provide common functionality and interfaces
for various components in the mesa-frames extension. These mixins are designed to
be used with the abstract base classes to create flexible and extensible
implementations.
Classes:
CopyMixin(ABC):
A mixin class that provides a fast copy method for classes that inherit it.
This is useful for creating efficient copies of large data structures, such
as DataFrames containing agent data.
DataFrameMixin(ABC):
A mixin class that defines an interface for DataFrame operations. This mixin
provides a common set of methods that should be implemented by concrete
backend classes (e.g. Polars implementations) to ensure consistent
DataFrame manipulation across the mesa-frames package.
These mixin classes are not meant to be instantiated directly. Instead, they should
be inherited alongside other base classes to add specific functionality or to
enforce a common interface.
Usage:
    Mixin classes are typically used in multiple inheritance scenarios:

        from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin

        class MyDataFrameClass(SomeBaseClass, CopyMixin, DataFrameMixin):
            def __init__(self):
                super().__init__()
                # Implementation

            # Implement abstract methods from DataFrameMixin
Note:
The DataFrameMixin uses Python's @abstractmethod decorator for its methods,
ensuring that classes inheriting from it must implement these methods.
Attributes and methods of each mixin class are documented in their respective
docstrings.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Collection, Hashable, Iterator, Sequence
from copy import copy, deepcopy
from typing import Any, Literal, Self, overload
from mesa_frames.types_ import (
BoolSeries,
DataFrame,
DataFrameInput,
Index,
Mask,
Series,
)
class CopyMixin(ABC):
"""A mixin class that provides a fast copy method for the class that inherits it."""
_copy_with_method: dict[str, tuple[str, list[str]]] = {}
_copy_only_reference: list[str] = [
"_model",
]
@abstractmethod
def __init__(self): ...
def copy(
self,
deep: bool = False,
memo: dict | None = None,
skip: list[str] | None = None,
) -> Self:
"""Create a copy of the Class.
Parameters
----------
deep : bool, optional
Flag indicating whether to perform a deep copy of the AgentContainer.
If True, all attributes of the AgentContainer will be recursively copied (except attributes in self._copy_reference_only).
If False, only the top-level attributes will be copied.
Defaults to False.
memo : dict | None, optional
A dictionary used to track already copied objects during deep copy.
Defaults to None.
skip : list[str] | None, optional
A list of attribute names to skip during the copy process.
Defaults to None.
Returns
-------
Self
A new instance of the AgentContainer class that is a copy of the original instance.
"""
cls = self.__class__
obj = cls.__new__(cls)
if skip is None:
skip = []
if deep:
if not memo:
memo = {}
memo[id(self)] = obj
attributes = self.__dict__.copy()
[
setattr(obj, k, deepcopy(v, memo))
for k, v in attributes.items()
if k not in self._copy_with_method
and k not in self._copy_only_reference
and k not in skip
]
else:
[
setattr(obj, k, copy(v))
for k, v in self.__dict__.items()
if k not in self._copy_with_method
and k not in self._copy_only_reference
and k not in skip
]
# Copy attributes with a reference only
for attr in self._copy_only_reference:
setattr(obj, attr, getattr(self, attr))
# Copy attributes with a specified method
for attr in self._copy_with_method:
attr_obj = getattr(self, attr)
attr_copy_method, attr_copy_args = self._copy_with_method[attr]
setattr(obj, attr, getattr(attr_obj, attr_copy_method)(*attr_copy_args))
return obj
def _get_obj(self, inplace: bool) -> Self:
"""Get the object to perform operations on.
Parameters
----------
inplace : bool
If inplace, return self. Otherwise, return a copy.
Returns
-------
Self
The object to perform operations on.
"""
if inplace:
return self
else:
return deepcopy(self)
def __copy__(self) -> Self:
"""Create a shallow copy of the AgentContainer.
Returns
-------
Self
A shallow copy of the AgentContainer.
"""
return self.copy(deep=False)
def __deepcopy__(self, memo: dict) -> Self:
"""Create a deep copy of the AgentContainer.
Parameters
----------
memo : dict
A dictionary to store the copied objects.
Returns
-------
Self
A deep copy of the AgentContainer.
"""
return self.copy(deep=True, memo=memo)
class DataFrameMixin(ABC):
    """A mixin class which defines an interface for DataFrame operations.

    Most methods are abstract and should be implemented by the concrete backend.
    The only concrete method is :meth:`_df_remove`, which is expressed in terms
    of :meth:`_df_get_masked_df`.

    Note: the one-line summaries on the abstract methods are derived from their
    names and signatures only; the exact semantics are defined by the concrete
    backend implementation.
    """

    def _df_remove(self, df: DataFrame, mask: Mask, index_cols: str) -> DataFrame:
        """Return ``df`` without the part selected by ``mask``.

        Delegates to ``_df_get_masked_df`` with ``negate=True``.

        Parameters
        ----------
        df : DataFrame
            The DataFrame to remove from.
        mask : Mask
            The selection to remove.
        index_cols : str
            Forwarded as the ``index_cols`` argument of ``_df_get_masked_df``.

        Returns
        -------
        DataFrame
            Whatever ``_df_get_masked_df`` returns for the negated mask.
        """
        return self._df_get_masked_df(df, index_cols, mask, negate=True)

    @abstractmethod
    def _df_and(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Combine ``df`` and ``other`` with a logical AND along ``axis``."""
        ...

    @abstractmethod
    def _df_add(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Add ``other`` to ``df`` along ``axis``."""
        ...

    @abstractmethod
    def _df_all(
        self,
        df: DataFrame,
        name: str = "all",
        axis: str = "columns",
    ) -> Series:
        """Reduce ``df`` with a logical "all" along ``axis`` into a series named ``name``."""
        ...

    @abstractmethod
    def _df_column_names(self, df: DataFrame) -> list[str]:
        """Return the column names of ``df``."""
        ...

    @abstractmethod
    def _df_combine_first(
        self, original_df: DataFrame, new_df: DataFrame, index_cols: str | list[str]
    ) -> DataFrame:
        """Combine ``original_df`` with ``new_df``, matching them on ``index_cols``."""
        ...

    @overload
    @abstractmethod
    def _df_concat(
        self,
        objs: Collection[Series],
        how: Literal["vertical"] = "vertical",
        ignore_index: bool = False,
        index_cols: str | None = None,
    ) -> Series: ...

    @overload
    @abstractmethod
    def _df_concat(
        self,
        objs: Collection[Series],
        how: Literal["horizontal"] = "horizontal",
        ignore_index: bool = False,
        index_cols: str | None = None,
    ) -> DataFrame: ...

    @overload
    @abstractmethod
    def _df_concat(
        self,
        objs: Collection[DataFrame],
        how: Literal["horizontal"] | Literal["vertical"] = "vertical",
        ignore_index: bool = False,
        index_cols: str | None = None,
    ) -> DataFrame: ...

    @abstractmethod
    def _df_concat(
        self,
        objs: Collection[DataFrame] | Collection[Series],
        how: Literal["horizontal"] | Literal["vertical"] = "vertical",
        ignore_index: bool = False,
        index_cols: str | None = None,
    ) -> DataFrame | Series:
        """Concatenate ``objs`` vertically or horizontally.

        Per the overloads: series concatenated vertically give a series, series
        concatenated horizontally give a DataFrame, and DataFrames always give
        a DataFrame.
        """
        ...

    @abstractmethod
    def _df_contains(
        self,
        df: DataFrame,
        column: str,
        values: Collection[Any],
    ) -> BoolSeries:
        """Return a boolean series for a membership test between ``values`` and ``column`` of ``df``."""
        ...

    @abstractmethod
    def _df_constructor(
        self,
        data: DataFrameInput | None = None,
        columns: list[str] | None = None,
        index: Index | None = None,
        index_cols: str | list[str] | None = None,
        dtypes: dict[str, Any] | None = None,
    ) -> DataFrame:
        """Build a backend DataFrame from ``data`` with the given columns, index and dtypes."""
        ...

    @abstractmethod
    def _df_div(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Divide ``df`` by ``other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_drop_columns(
        self,
        df: DataFrame,
        columns: str | list[str],
    ) -> DataFrame:
        """Return ``df`` without the given ``columns``."""
        ...

    @abstractmethod
    def _df_drop_duplicates(
        self,
        df: DataFrame,
        subset: str | list[str] | None = None,
        keep: Literal["first", "last", False] = "first",
    ) -> DataFrame:
        """Return ``df`` with duplicates (optionally judged on ``subset``) dropped according to ``keep``."""
        ...

    @abstractmethod
    def _df_ge(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Compare ``df >= other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_get_bool_mask(
        self,
        df: DataFrame,
        index_cols: str | list[str] | None = None,
        mask: Mask | None = None,
        negate: bool = False,
    ) -> BoolSeries:
        """Turn ``mask`` into a boolean series for ``df``; ``negate`` requests the inverse."""
        ...

    @abstractmethod
    def _df_get_masked_df(
        self,
        df: DataFrame,
        index_cols: str | list[str] | None = None,
        mask: Mask | None = None,
        columns: str | list[str] | None = None,
        negate: bool = False,
    ) -> DataFrame:
        """Return the part of ``df`` selected by ``mask`` (inverse if ``negate``), optionally limited to ``columns``."""
        ...

    @abstractmethod
    def _df_groupby_cumcount(
        self, df: DataFrame, by: str | list[str], name: str = "cum_count"
    ) -> Series:
        """Return a cumulative count within the groups defined by ``by`` as a series named ``name``."""
        ...

    @abstractmethod
    def _df_index(self, df: DataFrame, index_name: str | Collection[str]) -> Index:
        """Return the index of ``df`` identified by ``index_name``."""
        ...

    @abstractmethod
    def _df_iterator(self, df: DataFrame) -> Iterator[dict[str, Any]]:
        """Iterate over ``df``, yielding one ``dict`` keyed by string per item."""
        ...

    @abstractmethod
    def _df_join(
        self,
        left: DataFrame,
        right: DataFrame,
        index_cols: str | list[str] | None = None,
        on: str | list[str] | None = None,
        left_on: str | list[str] | None = None,
        right_on: str | list[str] | None = None,
        how: Literal["left"]
        | Literal["right"]
        | Literal["inner"]
        | Literal["outer"]
        | Literal["cross"] = "left",
        suffix: str = "_right",
    ) -> DataFrame:
        """Join ``left`` and ``right`` using the ``how`` strategy on the given key columns."""
        ...

    @abstractmethod
    def _df_lt(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Compare ``df < other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_mod(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Compute ``df`` modulo ``other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_mul(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Multiply ``df`` by ``other`` along ``axis``."""
        ...

    # NOTE: @overload must be the outermost decorator (as for _df_concat above).
    # With the reverse order, abstractmethod would be applied to the shared
    # placeholder function that typing.overload returns instead of to the stub.
    @overload
    @abstractmethod
    def _df_norm(
        self,
        df: DataFrame,
        srs_name: str = "norm",
        include_cols: Literal[False] = False,
    ) -> Series: ...

    @overload
    @abstractmethod
    def _df_norm(
        self,
        df: DataFrame,
        srs_name: str = "norm",
        include_cols: Literal[True] = True,
    ) -> DataFrame: ...

    @abstractmethod
    def _df_norm(
        self,
        df: DataFrame,
        srs_name: str = "norm",
        include_cols: bool = False,
    ) -> Series | DataFrame:
        """Compute the norm of ``df`` under the name ``srs_name``.

        Per the overloads: a series is returned when ``include_cols`` is False
        and a DataFrame when it is True.
        """
        ...

    @abstractmethod
    def _df_or(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Combine ``df`` and ``other`` with a logical OR along ``axis``."""
        ...

    @abstractmethod
    def _df_reindex(
        self,
        df: DataFrame,
        other: Sequence[Hashable] | Index,
        new_index_cols: str | list[str],
        original_index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Reindex ``df`` to the index values given in ``other``."""
        ...

    @abstractmethod
    def _df_rename_columns(
        self,
        df: DataFrame,
        old_columns: list[str],
        new_columns: list[str],
    ) -> DataFrame:
        """Rename the columns ``old_columns`` of ``df`` to ``new_columns``."""
        ...

    @abstractmethod
    def _df_reset_index(
        self,
        df: DataFrame,
        index_cols: str | list[str] | None = None,
        drop: bool = False,
    ) -> DataFrame:
        """Reset the index of ``df``; ``drop`` controls whether the old index is discarded."""
        ...

    @abstractmethod
    def _df_sample(
        self,
        df: DataFrame,
        n: int | None = None,
        frac: float | None = None,
        with_replacement: bool = False,
        shuffle: bool = False,
        seed: int | None = None,
    ) -> DataFrame:
        """Return a sample of ``df``, sized by ``n`` or ``frac``, optionally seeded."""
        ...

    @abstractmethod
    def _df_set_index(
        self,
        df: DataFrame,
        index_name: str | Collection[str],
        new_index: Sequence[Hashable] | None = None,
    ) -> DataFrame:
        """Set the index of ``df`` to ``index_name``, optionally using the values in ``new_index``."""
        ...

    @abstractmethod
    def _df_with_columns(
        self,
        original_df: DataFrame,
        data: DataFrame
        | Series
        | Sequence[Sequence]
        | dict[str, Any]
        | Collection[Any]
        | Any,
        new_columns: str | list[str] | None = None,
    ) -> DataFrame:
        """Return ``original_df`` with columns taken from ``data``, named by ``new_columns``."""
        ...

    @abstractmethod
    def _srs_constructor(
        self,
        data: Collection[Any] | None = None,
        name: str | None = None,
        dtype: Any | None = None,
        index: Collection[Any] | None = None,
    ) -> Series:
        """Build a backend series from ``data`` with the given name, dtype and index."""
        ...

    @abstractmethod
    def _srs_contains(
        self,
        srs: Collection[Any],
        values: Any | Collection[Any],
    ) -> BoolSeries:
        """Return a boolean series for a membership test between ``values`` and ``srs``."""
        ...

    @abstractmethod
    def _srs_range(self, name: str, start: int, end: int, step: int = 1) -> Series:
        """Build a series named ``name`` from the range described by ``start``, ``end`` and ``step``."""
        ...

    @abstractmethod
    def _srs_to_df(
        self, srs: Series, index: Collection[Any] | None = None
    ) -> DataFrame:
        """Convert ``srs`` to a DataFrame, optionally with the given ``index``."""
        ...