# Source code for mesa_frames.abstract.mixin

"""
Mixin classes for mesa-frames abstract components.

This module defines mixin classes that provide common functionality and interfaces
for various components in the mesa-frames extension. These mixins are designed to
be used with the abstract base classes to create flexible and extensible
implementations.

Classes:
    CopyMixin(ABC):
        A mixin class that provides a fast copy method for classes that inherit it.
        This is useful for creating efficient copies of large data structures, such
        as DataFrames containing agent data.

    DataFrameMixin(ABC):
        A mixin class that defines an interface for DataFrame operations. This mixin
        provides a common set of methods that should be implemented by concrete
        backend classes (e.g. Polars implementations) to ensure consistent
        DataFrame manipulation across the mesa-frames package.

These mixin classes are not meant to be instantiated directly. Instead, they should
be inherited alongside other base classes to add specific functionality or to
enforce a common interface.

Usage:
    Mixin classes are typically used in multiple inheritance scenarios:

    from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin

    class MyDataFrameClass(SomeBaseClass, CopyMixin, DataFrameMixin):
        def __init__(self):
            super().__init__()
            # Implementation

        # Implement abstract methods from DataFrameMixin

Note:
    The DataFrameMixin uses Python's @abstractmethod decorator for its methods,
    ensuring that classes inheriting from it must implement these methods.

Attributes and methods of each mixin class are documented in their respective
docstrings.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Collection, Hashable, Iterator, Sequence
from copy import copy, deepcopy
from typing import Any, Literal, Self, overload

from mesa_frames.types_ import (
    BoolSeries,
    DataFrame,
    DataFrameInput,
    Index,
    Mask,
    Series,
)


class CopyMixin(ABC):
    """A mixin class that provides a fast copy method for the class that inherits it."""

    _copy_with_method: dict[str, tuple[str, list[str]]] = {}
    _copy_only_reference: list[str] = [
        "_model",
    ]

    @abstractmethod
    def __init__(self): ...

    def copy(
        self,
        deep: bool = False,
        memo: dict | None = None,
        skip: list[str] | None = None,
    ) -> Self:
        """Create a copy of the Class.

        Parameters
        ----------
        deep : bool, optional
            Flag indicating whether to perform a deep copy of the AgentContainer.
            If True, all attributes of the AgentContainer will be recursively copied (except attributes in self._copy_reference_only).
            If False, only the top-level attributes will be copied.
            Defaults to False.
        memo : dict | None, optional
            A dictionary used to track already copied objects during deep copy.
            Defaults to None.
        skip : list[str] | None, optional
            A list of attribute names to skip during the copy process.
            Defaults to None.

        Returns
        -------
        Self
            A new instance of the AgentContainer class that is a copy of the original instance.
        """
        cls = self.__class__
        obj = cls.__new__(cls)

        if skip is None:
            skip = []

        if deep:
            if not memo:
                memo = {}
            memo[id(self)] = obj
            attributes = self.__dict__.copy()
            [
                setattr(obj, k, deepcopy(v, memo))
                for k, v in attributes.items()
                if k not in self._copy_with_method
                and k not in self._copy_only_reference
                and k not in skip
            ]
        else:
            [
                setattr(obj, k, copy(v))
                for k, v in self.__dict__.items()
                if k not in self._copy_with_method
                and k not in self._copy_only_reference
                and k not in skip
            ]

        # Copy attributes with a reference only
        for attr in self._copy_only_reference:
            setattr(obj, attr, getattr(self, attr))

        # Copy attributes with a specified method
        for attr in self._copy_with_method:
            attr_obj = getattr(self, attr)
            attr_copy_method, attr_copy_args = self._copy_with_method[attr]
            setattr(obj, attr, getattr(attr_obj, attr_copy_method)(*attr_copy_args))

        return obj

    def _get_obj(self, inplace: bool) -> Self:
        """Get the object to perform operations on.

        Parameters
        ----------
        inplace : bool
            If inplace, return self. Otherwise, return a copy.

        Returns
        -------
        Self
            The object to perform operations on.
        """
        if inplace:
            return self
        else:
            return deepcopy(self)

    def __copy__(self) -> Self:
        """Create a shallow copy of the AgentContainer.

        Returns
        -------
        Self
            A shallow copy of the AgentContainer.
        """
        return self.copy(deep=False)

    def __deepcopy__(self, memo: dict) -> Self:
        """Create a deep copy of the AgentContainer.

        Parameters
        ----------
        memo : dict
            A dictionary to store the copied objects.

        Returns
        -------
        Self
            A deep copy of the AgentContainer.
        """
        return self.copy(deep=True, memo=memo)


class DataFrameMixin(ABC):
    """A mixin class which defines an interface for DataFrame operations.

    Every method except ``_df_remove`` is abstract and must be implemented by
    the concrete backend.

    NOTE(review): the one-line summaries on the abstract methods are inferred
    from each method's name and signature; the exact semantics are defined by
    the concrete backend and should be confirmed there.
    """

    def _df_remove(
        self, df: DataFrame, mask: Mask, index_cols: str | list[str]
    ) -> DataFrame:
        """Return ``df`` filtered by the negation of ``mask``.

        Delegates to ``_df_get_masked_df`` with ``negate=True``.

        Parameters
        ----------
        df : DataFrame
            The DataFrame to filter.
        mask : Mask
            The mask passed on to ``_df_get_masked_df``.
        index_cols : str | list[str]
            The index column(s) passed on to ``_df_get_masked_df``.

        Returns
        -------
        DataFrame
            Whatever ``_df_get_masked_df`` returns for the negated mask.
        """
        return self._df_get_masked_df(df, index_cols, mask, negate=True)

    @abstractmethod
    def _df_and(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Return the AND of ``df`` and ``other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_add(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Return the sum of ``df`` and ``other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_all(
        self,
        df: DataFrame,
        name: str = "all",
        axis: str = "columns",
    ) -> Series:
        """Reduce ``df`` along ``axis`` to a Series called ``name``."""
        ...

    @abstractmethod
    def _df_column_names(self, df: DataFrame) -> list[str]:
        """Return the column names of ``df``."""
        ...

    @abstractmethod
    def _df_combine_first(
        self, original_df: DataFrame, new_df: DataFrame, index_cols: str | list[str]
    ) -> DataFrame:
        """Combine ``original_df`` with ``new_df``, aligned on ``index_cols``."""
        ...

    # ``_df_concat`` overloads: the return type depends on the kind of objects
    # being concatenated and on ``how``.
    @overload
    @abstractmethod
    def _df_concat(
        self,
        objs: Collection[Series],
        how: Literal["vertical"] = "vertical",
        ignore_index: bool = False,
        index_cols: str | None = None,
    ) -> Series: ...

    # ``how`` carries no default here: the implementation defaults to
    # "vertical", so a horizontal concatenation must always be requested
    # explicitly.
    @overload
    @abstractmethod
    def _df_concat(
        self,
        objs: Collection[Series],
        how: Literal["horizontal"],
        ignore_index: bool = False,
        index_cols: str | None = None,
    ) -> DataFrame: ...

    @overload
    @abstractmethod
    def _df_concat(
        self,
        objs: Collection[DataFrame],
        how: Literal["horizontal"] | Literal["vertical"] = "vertical",
        ignore_index: bool = False,
        index_cols: str | None = None,
    ) -> DataFrame: ...

    @abstractmethod
    def _df_concat(
        self,
        objs: Collection[DataFrame] | Collection[Series],
        how: Literal["horizontal"] | Literal["vertical"] = "vertical",
        ignore_index: bool = False,
        index_cols: str | None = None,
    ) -> DataFrame | Series:
        """Concatenate ``objs`` vertically or horizontally, as selected by ``how``."""
        ...

    @abstractmethod
    def _df_contains(
        self,
        df: DataFrame,
        column: str,
        values: Collection[Any],
    ) -> BoolSeries:
        """Return a boolean Series testing ``values`` against ``column`` of ``df``."""
        ...

    @abstractmethod
    def _df_constructor(
        self,
        data: DataFrameInput | None = None,
        columns: list[str] | None = None,
        index: Index | None = None,
        index_cols: str | list[str] | None = None,
        dtypes: dict[str, Any] | None = None,
    ) -> DataFrame:
        """Build a backend DataFrame from ``data``."""
        ...

    @abstractmethod
    def _df_div(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Return ``df`` divided by ``other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_drop_columns(
        self,
        df: DataFrame,
        columns: str | list[str],
    ) -> DataFrame:
        """Return ``df`` without ``columns``."""
        ...

    @abstractmethod
    def _df_drop_duplicates(
        self,
        df: DataFrame,
        subset: str | list[str] | None = None,
        keep: Literal["first", "last", False] = "first",
    ) -> DataFrame:
        """Return ``df`` with duplicated rows dropped; ``keep`` picks the survivor."""
        ...

    @abstractmethod
    def _df_ge(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Return the comparison ``df >= other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_get_bool_mask(
        self,
        df: DataFrame,
        index_cols: str | list[str] | None = None,
        mask: Mask | None = None,
        negate: bool = False,
    ) -> BoolSeries:
        """Turn ``mask`` into a boolean Series for ``df``, optionally negated."""
        ...

    @abstractmethod
    def _df_get_masked_df(
        self,
        df: DataFrame,
        index_cols: str | list[str] | None = None,
        mask: Mask | None = None,
        columns: str | list[str] | None = None,
        negate: bool = False,
    ) -> DataFrame:
        """Return the part of ``df`` selected by ``mask`` (or its negation)."""
        ...

    @abstractmethod
    def _df_groupby_cumcount(
        self, df: DataFrame, by: str | list[str], name: str = "cum_count"
    ) -> Series:
        """Return a cumulative count per group of ``by`` as a Series called ``name``."""
        ...

    @abstractmethod
    def _df_index(self, df: DataFrame, index_name: str | Collection[str]) -> Index:
        """Return the index of ``df`` identified by ``index_name``."""
        ...

    @abstractmethod
    def _df_iterator(self, df: DataFrame) -> Iterator[dict[str, Any]]:
        """Iterate over ``df``, yielding one ``dict[str, Any]`` per item."""
        ...

    @abstractmethod
    def _df_join(
        self,
        left: DataFrame,
        right: DataFrame,
        index_cols: str | list[str] | None = None,
        on: str | list[str] | None = None,
        left_on: str | list[str] | None = None,
        right_on: str | list[str] | None = None,
        how: Literal["left"]
        | Literal["right"]
        | Literal["inner"]
        | Literal["outer"]
        | Literal["cross"] = "left",
        suffix: str = "_right",
    ) -> DataFrame:
        """Join ``left`` with ``right`` using the strategy named by ``how``."""
        ...

    @abstractmethod
    def _df_lt(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Return the comparison ``df < other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_mod(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Return ``df`` modulo ``other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_mul(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Return ``df`` multiplied by ``other`` along ``axis``."""
        ...

    # ``_df_norm`` overloads: ``include_cols`` selects the return type.
    # ``@overload`` is kept outermost (as for ``_df_concat``) so that
    # ``@abstractmethod`` marks the stub itself rather than the placeholder
    # object that ``typing.overload`` returns.
    @overload
    @abstractmethod
    def _df_norm(
        self,
        df: DataFrame,
        srs_name: str = "norm",
        include_cols: Literal[False] = False,
    ) -> Series: ...

    @overload
    @abstractmethod
    def _df_norm(
        self,
        df: DataFrame,
        srs_name: str = "norm",
        include_cols: Literal[True] = True,
    ) -> DataFrame: ...

    @abstractmethod
    def _df_norm(
        self,
        df: DataFrame,
        srs_name: str = "norm",
        include_cols: bool = False,
    ) -> Series | DataFrame:
        """Return the norm of ``df`` as ``srs_name``; a DataFrame if ``include_cols``."""
        ...

    @abstractmethod
    def _df_or(
        self,
        df: DataFrame,
        other: DataFrame | Sequence[float | int],
        axis: Literal["index", "columns"] = "index",
        index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Return the OR of ``df`` and ``other`` along ``axis``."""
        ...

    @abstractmethod
    def _df_reindex(
        self,
        df: DataFrame,
        other: Sequence[Hashable] | Index,
        new_index_cols: str | list[str],
        original_index_cols: str | list[str] | None = None,
    ) -> DataFrame:
        """Return ``df`` re-indexed on ``other``."""
        ...

    @abstractmethod
    def _df_rename_columns(
        self,
        df: DataFrame,
        old_columns: list[str],
        new_columns: list[str],
    ) -> DataFrame:
        """Return ``df`` with ``old_columns`` renamed to ``new_columns``."""
        ...

    @abstractmethod
    def _df_reset_index(
        self,
        df: DataFrame,
        index_cols: str | list[str] | None = None,
        drop: bool = False,
    ) -> DataFrame:
        """Return ``df`` with its index reset; ``drop`` controls whether it is kept."""
        ...

    @abstractmethod
    def _df_sample(
        self,
        df: DataFrame,
        n: int | None = None,
        frac: float | None = None,
        with_replacement: bool = False,
        shuffle: bool = False,
        seed: int | None = None,
    ) -> DataFrame:
        """Return a sample of ``df`` sized by ``n`` or ``frac``."""
        ...

    @abstractmethod
    def _df_set_index(
        self,
        df: DataFrame,
        index_name: str | Collection[str],
        new_index: Sequence[Hashable] | None = None,
    ) -> DataFrame:
        """Return ``df`` indexed by ``index_name``, optionally using ``new_index``."""
        ...

    @abstractmethod
    def _df_with_columns(
        self,
        original_df: DataFrame,
        data: DataFrame
        | Series
        | Sequence[Sequence]
        | dict[str, Any]
        | Collection[Any]
        | Any,
        new_columns: str | list[str] | None = None,
    ) -> DataFrame:
        """Return ``original_df`` with ``data`` added as ``new_columns``."""
        ...

    @abstractmethod
    def _srs_constructor(
        self,
        data: Collection[Any] | None = None,
        name: str | None = None,
        dtype: Any | None = None,
        index: Collection[Any] | None = None,
    ) -> Series:
        """Build a backend Series from ``data``."""
        ...

    @abstractmethod
    def _srs_contains(
        self,
        srs: Collection[Any],
        values: Any | Collection[Any],
    ) -> BoolSeries:
        """Return a boolean Series testing ``values`` against ``srs``."""
        ...

    @abstractmethod
    def _srs_range(self, name: str, start: int, end: int, step: int = 1) -> Series:
        """Return a Series called ``name`` built from ``start``, ``end`` and ``step``."""
        ...

    @abstractmethod
    def _srs_to_df(
        self, srs: Series, index: Collection[Any] | None = None
    ) -> DataFrame:
        """Convert ``srs`` to a DataFrame."""
        ...