# Source code for mesa_frames.concrete.polars.agentset

"""
Polars-based implementation of AgentSet for mesa-frames.

This module provides a concrete implementation of the AgentSet class using Polars
as the backend for DataFrame operations. It defines the AgentSetPolars class,
which combines the abstract AgentSetDF functionality with Polars-specific
operations for efficient agent management and manipulation.

Classes:
    AgentSetPolars(AgentSetDF, PolarsMixin):
        A Polars-based implementation of the AgentSet. This class uses Polars
        DataFrames to store and manipulate agent data, providing high-performance
        operations for large numbers of agents.

The AgentSetPolars class is designed to be used within ModelDF instances or as
part of an AgentsDF collection. It leverages the power of Polars for fast and
efficient data operations on agent attributes and behaviors.

Usage:
    The AgentSetPolars class can be used directly in a model or as part of an
    AgentsDF collection:

    from mesa_frames.concrete.model import ModelDF
    from mesa_frames.concrete.polars.agentset import AgentSetPolars
    import polars as pl

    class MyAgents(AgentSetPolars):
        def __init__(self, model):
            super().__init__(model)
            # Initialize with some agents
            self.add(pl.DataFrame({'unique_id': range(100), 'wealth': 10}))

        def step(self):
            # Implement step behavior using Polars operations
            self.agents = self.agents.with_columns(new_wealth = pl.col('wealth') + 1)

    class MyModel(ModelDF):
        def __init__(self):
            super().__init__()
            self.agents += MyAgents(self)

        def step(self):
            self.agents.step()

Features:
    - Efficient storage and manipulation of large agent populations
    - Fast vectorized operations on agent attributes
    - Support for lazy evaluation and query optimization
    - Seamless integration with other mesa-frames components

Note:
    This implementation relies on Polars, so users should ensure that Polars
    is installed and imported. The performance characteristics of this class
    will depend on the Polars version and the specific operations used.

For more detailed information on the AgentSetPolars class and its methods,
refer to the class docstring.
"""

from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from typing import TYPE_CHECKING

import polars as pl
from polars._typing import IntoExpr
from typing_extensions import Any, Self, overload

from mesa_frames.concrete.agents import AgentSetDF
from mesa_frames.concrete.polars.mixin import PolarsMixin
from mesa_frames.types_ import AgentPolarsMask, PolarsIdsLike
from mesa_frames.utils import copydoc

if TYPE_CHECKING:
    from mesa_frames.concrete.model import ModelDF
    from mesa_frames.concrete.pandas.agentset import AgentSetPandas

import numpy as np


@copydoc(AgentSetDF)
class AgentSetPolars(AgentSetDF, PolarsMixin):
    """Polars-based implementation of AgentSetDF."""

    # One row per agent; the "unique_id" column identifies the agent.
    _agents: pl.DataFrame
    # Attribute name -> (method name, positional args).
    # NOTE(review): presumably consumed by the copy logic of a base class that
    # is not visible in this file — confirm there.
    _copy_with_method: dict[str, tuple[str, list[str]]] = {
        "_agents": ("clone", []),
    }
    # NOTE(review): presumably attributes shared by reference between copies —
    # confirm against the base class.
    _copy_only_reference: list[str] = ["_model", "_mask"]
    # Selection of "active" agents: either a boolean Series aligned
    # positionally with `_agents`, or a polars expression evaluated on it.
    _mask: pl.Expr | pl.Series

    def __init__(self, model: "ModelDF") -> None:
        """Initialize a new AgentSetPolars.

        Parameters
        ----------
        model : ModelDF
            The model that the agent set belongs to.
        """
        self._model = model
        self._agents = pl.DataFrame(schema={"unique_id": pl.Int64})
        self._mask = pl.repeat(True, len(self._agents), dtype=pl.Boolean, eager=True)

    def add(
        self,
        agents: pl.DataFrame | Sequence[Any] | dict[str, Any],
        inplace: bool = True,
    ) -> Self:
        """Add agents to the AgentSetPolars.

        Parameters
        ----------
        agents : pl.DataFrame | Sequence[Any] | dict[str, Any]
            The agents to add. A DataFrame or dict must provide ``unique_id``;
            a plain sequence is read as ONE agent whose values follow the
            current column order.
        inplace : bool, optional
            Whether to add the agents in place, by default True.

        Returns
        -------
        Self
            The updated AgentSetPolars.

        Raises
        ------
        KeyError
            If a DataFrame or dict has no ``unique_id`` entry.
        ValueError
            If a sequence does not have one value per existing column.
        TypeError
            If the ``unique_id`` column is not Int64.
        """
        obj = self._get_obj(inplace)

        if isinstance(agents, pl.DataFrame):
            if "unique_id" not in agents.columns:
                raise KeyError("DataFrame must have a unique_id column.")
            new_agents = agents
        elif isinstance(agents, dict):
            if "unique_id" not in agents:
                raise KeyError("Dictionary must have a unique_id key.")
            new_agents = pl.DataFrame(agents)
        else:
            if len(agents) != len(obj._agents.columns):
                raise ValueError(
                    "Length of data must match the number of columns in the AgentSet if being added as a Collection."
                )
            # The sequence is a single row: state the orientation explicitly
            # instead of relying on (warning-emitting) inference.
            new_agents = pl.DataFrame(
                [agents], schema=obj._agents.schema, orient="row"
            )

        if new_agents["unique_id"].dtype != pl.Int64:
            raise TypeError("unique_id column must be of type int64.")

        # An expression mask is re-evaluated lazily and needs no update; a
        # boolean Series is positional and must be rebuilt for the new rows.
        active_ids = obj._snapshot_active_ids()
        obj._agents = pl.concat([obj._agents, new_agents], how="diagonal_relaxed")
        if active_ids is not None:
            obj._update_mask(active_ids, new_agents["unique_id"])

        return obj

    @overload
    def contains(self, agents: int) -> bool: ...

    @overload
    def contains(self, agents: PolarsIdsLike) -> pl.Series: ...

    def contains(
        self,
        agents: PolarsIdsLike,
    ) -> bool | pl.Series:
        """Check whether ids are present in the ``unique_id`` column.

        Parameters
        ----------
        agents : PolarsIdsLike
            A single id, a collection of ids, or a Series of ids.

        Returns
        -------
        bool | pl.Series
            A boolean Series (one entry per requested id) for a Series or
            collection, otherwise a single bool.
        """
        known_ids = self._agents["unique_id"]
        if isinstance(agents, pl.Series):
            return agents.is_in(known_ids)
        if isinstance(agents, Collection):
            return pl.Series(agents).is_in(known_ids)
        return agents in known_ids

    def get(
        self,
        attr_names: IntoExpr | Iterable[IntoExpr] | None,
        mask: AgentPolarsMask = None,
    ) -> pl.Series | pl.DataFrame:
        """Return attributes of the agents selected by ``mask``.

        Parameters
        ----------
        attr_names : IntoExpr | Iterable[IntoExpr] | None
            Columns to retrieve. When this resolves to no column, the whole
            masked table is returned.
        mask : AgentPolarsMask, optional
            Agents to read from, by default None.

        Returns
        -------
        pl.Series | pl.DataFrame
            A Series when exactly one column is selected, else a DataFrame.
        """
        masked_df = self._get_masked_df(mask)
        # Resolve the selection against the full table to obtain column names.
        attr_names = self.agents.select(attr_names).columns.copy()
        if not attr_names:
            return masked_df
        masked_df = masked_df.select(attr_names)
        if masked_df.shape[1] == 1:
            return masked_df[masked_df.columns[0]]
        return masked_df

    def set(
        self,
        attr_names: str | Collection[str] | dict[str, Any] | None = None,
        values: Any | None = None,
        mask: AgentPolarsMask = None,
        inplace: bool = True,
    ) -> Self:
        """Set attribute values for the agents selected by ``mask``.

        Parameters
        ----------
        attr_names : str | Collection[str] | dict[str, Any] | None, optional
            A column name, several column names, or a ``{column: value}``
            mapping. When empty, every column except ``unique_id`` is targeted.
        values : Any | None, optional
            Value(s) to assign; ignored when ``attr_names`` is a dict used
            without ``values``. Strings and bytes are broadcast as scalars.
        mask : AgentPolarsMask, optional
            Agents to update, by default None.
        inplace : bool, optional
            Whether to update in place, by default True.

        Returns
        -------
        Self
            The updated AgentSetPolars.

        Raises
        ------
        ValueError
            If ``attr_names`` / ``values`` do not form a supported combination.
        """
        obj = self._get_obj(inplace)
        b_mask = obj._get_bool_mask(mask)
        masked_df = obj._get_masked_df(mask)

        if not attr_names:
            attr_names = masked_df.columns
            attr_names.remove("unique_id")

        def is_value_collection(candidate: Any) -> bool:
            # str/bytes are Collections, but as attribute values they are
            # scalars: pl.Series("foo") would build an EMPTY series named "foo".
            return isinstance(candidate, Collection) and not isinstance(
                candidate, (str, bytes)
            )

        def process_single_attr(
            masked_df: pl.DataFrame, attr_name: str, values: Any
        ) -> pl.DataFrame:
            if isinstance(values, pl.DataFrame):
                column = values.to_series()
            elif isinstance(values, (pl.Expr, pl.Series)):
                column = values
            elif is_value_collection(values):
                column = pl.Series(values)
            else:
                # Scalar: broadcast to every selected row.
                column = pl.repeat(values, len(masked_df))
            return masked_df.with_columns(column.alias(attr_name))

        if isinstance(attr_names, str) and values is not None:
            masked_df = process_single_attr(masked_df, attr_names, values)
        elif isinstance(attr_names, Collection) and values is not None:
            if is_value_collection(values) and len(attr_names) == len(values):
                # One value per attribute.
                for attribute, val in zip(attr_names, values):
                    masked_df = process_single_attr(masked_df, attribute, val)
            else:
                # Same value for every attribute.
                for attribute in attr_names:
                    masked_df = process_single_attr(masked_df, attribute, values)
        elif isinstance(attr_names, dict):
            for key, val in attr_names.items():
                masked_df = process_single_attr(masked_df, key, val)
        else:
            raise ValueError(
                "attr_names must be a string, a collection of string or a dictionary with columns as keys and values."
            )

        # Stitch updated and untouched rows back together, then restore the
        # original row order through a left join on the original ids.
        non_masked_df = obj._agents.filter(b_mask.not_())
        original_index = obj._agents.select("unique_id")
        obj._agents = pl.concat([non_masked_df, masked_df], how="diagonal_relaxed")
        obj._agents = original_index.join(obj._agents, on="unique_id", how="left")
        return obj

    def select(
        self,
        mask: AgentPolarsMask = None,
        filter_func: Callable[[Self], pl.Series] | None = None,
        n: int | None = None,
        negate: bool = False,
        inplace: bool = True,
    ) -> Self:
        """Choose the active agents and store the selection as the mask.

        Parameters
        ----------
        mask : AgentPolarsMask, optional
            Starting selection, by default None.
        filter_func : Callable[[Self], pl.Series] | None, optional
            Called with the agent set; its result is AND-ed with the mask.
        n : int | None, optional
            When given, keep a random sample of ``n`` of the selected agents.
            The sample is seeded from ``self.random`` (as ``shuffle`` does).
        negate : bool, optional
            Whether to invert the final selection, by default False.
        inplace : bool, optional
            Whether to update in place, by default True.

        Returns
        -------
        Self
            The AgentSetPolars with the new active mask.
        """
        obj = self._get_obj(inplace)
        mask = obj._get_bool_mask(mask)
        if filter_func:
            mask = mask & filter_func(obj)
        if n is not None:
            sampled_ids = obj._agents.filter(mask).sample(
                n, seed=obj.random.integers(np.iinfo(np.int32).max)
            )["unique_id"]
            mask = obj._agents["unique_id"].is_in(sampled_ids)
        if negate:
            mask = mask.not_()
        obj._mask = mask
        return obj

    def shuffle(self, inplace: bool = True) -> Self:
        """Randomly reorder the agents.

        A positional boolean mask is realigned afterwards so that the same
        agents remain active.

        Parameters
        ----------
        inplace : bool, optional
            Whether to shuffle in place, by default True.

        Returns
        -------
        Self
            The shuffled AgentSetPolars.
        """
        obj = self._get_obj(inplace)
        active_ids = obj._snapshot_active_ids()
        obj._agents = obj._agents.sample(
            fraction=1,
            shuffle=True,
            seed=obj.random.integers(np.iinfo(np.int32).max),
        )
        if active_ids is not None:
            obj._update_mask(active_ids)
        return obj

    def sort(
        self,
        by: str | Sequence[str],
        ascending: bool | Sequence[bool] = True,
        inplace: bool = True,
        **kwargs,
    ) -> Self:
        """Sort the agents by one or more columns.

        A positional boolean mask is realigned afterwards so that the same
        agents remain active.

        Parameters
        ----------
        by : str | Sequence[str]
            Column(s) to sort by.
        ascending : bool | Sequence[bool], optional
            Sort direction, globally or per column, by default True.
        inplace : bool, optional
            Whether to sort in place, by default True.
        **kwargs
            Forwarded to ``pl.DataFrame.sort``.

        Returns
        -------
        Self
            The sorted AgentSetPolars.
        """
        obj = self._get_obj(inplace)
        # polars expresses direction as `descending`, the inverse of ours.
        if isinstance(ascending, bool):
            descending = not ascending
        else:
            descending = [not a for a in ascending]
        active_ids = obj._snapshot_active_ids()
        obj._agents = obj._agents.sort(by=by, descending=descending, **kwargs)
        if active_ids is not None:
            obj._update_mask(active_ids)
        return obj

    def to_pandas(self) -> "AgentSetPandas":
        """Build an AgentSetPandas holding the same agents and active mask."""
        # Imported lazily: the pandas backend is only needed for conversion.
        from mesa_frames.concrete.pandas.agentset import AgentSetPandas

        new_obj = AgentSetPandas(self._model)
        new_obj._agents = self._agents.to_pandas()
        if isinstance(self._mask, pl.Series):
            new_obj._mask = self._mask.to_pandas()
        else:
            # Expression mask: evaluate it into a boolean Series first.
            active_ids = self._agents.filter(self._mask)["unique_id"]
            new_obj._mask = self._agents["unique_id"].is_in(active_ids).to_pandas()
        return new_obj

    def _concatenate_agentsets(
        self,
        agentsets: Iterable[Self],
        duplicates_allowed: bool = True,
        keep_first_only: bool = True,
        original_masked_index: pl.Series | None = None,
    ) -> Self:
        """Merge several agent sets into this one (in place).

        Parameters
        ----------
        agentsets : Iterable[Self]
            The agent sets to merge.
        duplicates_allowed : bool, optional
            When False, any id shared between ``self`` and the given sets
            raises, by default True.
        keep_first_only : bool, optional
            With duplicates allowed, keep only the first occurrence of each
            id, by default True.
        original_masked_index : pl.Series | None, optional
            Ids that were originally selected; those missing from the merged
            result are removed via ``self.remove``.

        Returns
        -------
        Self
            This AgentSetPolars, updated.

        Raises
        ------
        ValueError
            If duplicates are forbidden and some ids are duplicated.
        """
        # The argument is traversed several times; materialise it once so a
        # generator is not silently exhausted after the first pass.
        agentsets = list(agentsets)

        if not duplicates_allowed:
            indices_list = [self._agents["unique_id"]] + [
                agentset._agents["unique_id"] for agentset in agentsets
            ]
            all_indices = pl.concat(indices_list)
            if all_indices.is_duplicated().any():
                raise ValueError(
                    "Some ids are duplicated in the AgentSetDFs that are trying to be concatenated"
                )

        if duplicates_allowed and keep_first_only:
            # Find the original_index list (ie longest index list), to sort
            # correctly the rows after concatenation. The LAST set of maximal
            # length wins.
            max_length = max(len(agentset) for agentset in agentsets)
            for agentset in agentsets:
                if len(agentset) == max_length:
                    original_index = agentset._agents["unique_id"]

            final_dfs = [self._agents]
            # Every id of `self` is counted as active here.
            final_active_indices = [self._agents["unique_id"]]
            final_indices = self._agents["unique_id"].clone()
            for obj in agentsets:
                # Remove agents that are already in the final DataFrame
                final_dfs.append(
                    obj._agents.filter(pl.col("unique_id").is_in(final_indices).not_())
                )
                # Add the indices of the active agents of current AgentSet
                final_active_indices.append(obj._agents.filter(obj._mask)["unique_id"])
                # Update the indices of the agents in the final DataFrame
                final_indices = pl.concat(
                    [final_indices, final_dfs[-1]["unique_id"]], how="vertical"
                )
            # Left-join original index with concatenated dfs to keep original ids order
            final_df = original_index.to_frame().join(
                pl.concat(final_dfs, how="diagonal_relaxed"),
                on="unique_id",
                how="left",
            )
            final_active_index = pl.concat(final_active_indices, how="vertical")
        else:
            final_df = pl.concat(
                [obj._agents for obj in agentsets], how="diagonal_relaxed"
            )
            final_active_index = pl.concat(
                [obj._agents.filter(obj._mask)["unique_id"] for obj in agentsets]
            )

        final_mask = final_df["unique_id"].is_in(final_active_index)
        self._agents = final_df
        self._mask = final_mask

        # If some ids were removed in the do-method, we need to remove them
        # also from final_df
        if original_masked_index is not None:
            ids_to_remove = original_masked_index.filter(
                original_masked_index.is_in(self._agents["unique_id"]).not_()
            )
            if not ids_to_remove.is_empty():
                self.remove(ids_to_remove, inplace=True)
        return self

    def _get_bool_mask(
        self,
        mask: AgentPolarsMask = None,
    ) -> pl.Series | pl.Expr:
        """Normalise any supported mask into a boolean Series or expression.

        Parameters
        ----------
        mask : AgentPolarsMask, optional
            An expression (returned untouched), a boolean or id Series, a
            DataFrame with ``unique_id`` or a single boolean column, ``None`` /
            ``"all"``, ``"active"``, a collection of ids or a single id.

        Returns
        -------
        pl.Series | pl.Expr
            A mask usable with ``self._agents.filter``.

        Raises
        ------
        KeyError
            If a DataFrame mask has neither ``unique_id`` nor a single boolean
            column.
        """

        def bool_mask_from_series(mask: pl.Series) -> pl.Series:
            # A boolean Series of matching length already is the mask;
            # anything else is interpreted as a list of ids.
            if (
                isinstance(mask, pl.Series)
                and mask.dtype == pl.Boolean
                and len(mask) == len(self._agents)
            ):
                return mask
            return self._agents["unique_id"].is_in(mask)

        if isinstance(mask, pl.Expr):
            return mask
        if isinstance(mask, pl.Series):
            return bool_mask_from_series(mask)
        if isinstance(mask, pl.DataFrame):
            if "unique_id" in mask.columns:
                return bool_mask_from_series(mask["unique_id"])
            if len(mask.columns) == 1 and mask.dtypes[0] == pl.Boolean:
                return bool_mask_from_series(mask[mask.columns[0]])
            raise KeyError(
                "DataFrame must have a 'unique_id' column or a single boolean column."
            )
        # Keyword comparisons are restricted to real strings so array-like id
        # containers never go through an element-wise `==`.
        if mask is None or (isinstance(mask, str) and mask == "all"):
            # Eager on purpose: a lazy repeat() would freeze the current
            # length and go stale once agents are added or removed.
            return pl.repeat(True, len(self._agents), dtype=pl.Boolean, eager=True)
        if isinstance(mask, str) and mask == "active":
            return self._mask
        if isinstance(mask, Collection):
            return bool_mask_from_series(pl.Series(mask))
        return bool_mask_from_series(pl.Series([mask]))

    def _get_masked_df(
        self,
        mask: AgentPolarsMask = None,
    ) -> pl.DataFrame:
        """Return the rows of the agent table selected by ``mask``.

        Boolean Series and expressions filter the table (keeping its order);
        id-based masks return the rows in the order of the given ids.

        Raises
        ------
        KeyError
            If an id-based mask contains ids missing from the table.
        """

        def rows_for_ids(ids: pl.Series) -> pl.DataFrame:
            if not ids.is_in(self._agents["unique_id"]).all():
                raise KeyError(
                    "Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
                )
            # Left join from the ids keeps the caller's id order.
            return ids.to_frame("unique_id").join(
                self._agents, on="unique_id", how="left"
            )

        if isinstance(mask, pl.Expr) or (
            isinstance(mask, pl.Series) and mask.dtype == pl.Boolean
        ):
            return self._agents.filter(mask)
        if isinstance(mask, pl.DataFrame):
            return rows_for_ids(mask["unique_id"])
        if isinstance(mask, pl.Series):
            return rows_for_ids(mask)
        if mask is None or (isinstance(mask, str) and mask == "all"):
            return self._agents
        if isinstance(mask, str) and mask == "active":
            return self._agents.filter(self._mask)
        if isinstance(mask, Collection):
            return rows_for_ids(pl.Series(mask))
        return rows_for_ids(pl.Series([mask]))

    @overload
    def _get_obj_copy(self, obj: pl.Series) -> pl.Series: ...

    @overload
    def _get_obj_copy(self, obj: pl.DataFrame) -> pl.DataFrame: ...

    def _get_obj_copy(self, obj: pl.Series | pl.DataFrame) -> pl.Series | pl.DataFrame:
        """Return a clone of a polars Series or DataFrame."""
        return obj.clone()

    def _discard(self, ids: PolarsIdsLike) -> Self:
        """Drop the agents matching ``ids`` (in place) and realign the mask."""
        mask = self._get_bool_mask(ids)
        active_ids = self._snapshot_active_ids()
        self._agents = self._agents.filter(mask.not_())
        if active_ids is not None:
            self._update_mask(active_ids)
        return self

    def _snapshot_active_ids(self) -> pl.Series | None:
        """Return the ids of active agents when the mask is a boolean Series.

        Returns None for an expression mask, which is evaluated lazily and
        therefore needs no realignment after the table changes.
        """
        if isinstance(self._mask, pl.Series):
            return self._agents.filter(self._mask)["unique_id"]
        return None

    def _update_mask(
        self, original_active_indices: pl.Series, new_indices: pl.Series | None = None
    ) -> None:
        """Rebuild the boolean mask from ids after the table changed.

        Parameters
        ----------
        original_active_indices : pl.Series
            Ids that were active before the change.
        new_indices : pl.Series | None, optional
            Ids of newly added agents, which become active as well.
        """
        ids = self._agents["unique_id"]
        new_mask = ids.is_in(original_active_indices)
        if new_indices is not None:
            new_mask = new_mask | ids.is_in(new_indices)
        self._mask = new_mask

    def __getattr__(self, key: str) -> pl.Series:
        """Expose a column of the agent table as an attribute."""
        # Call the parent hook first; its return value is not used here.
        super().__getattr__(key)
        return self._agents[key]

    @overload
    def __getitem__(
        self,
        key: str | tuple[AgentPolarsMask, str],
    ) -> pl.Series: ...

    @overload
    def __getitem__(
        self,
        key: (
            AgentPolarsMask
            | Collection[str]
            | tuple[
                AgentPolarsMask,
                Collection[str],
            ]
        ),
    ) -> pl.DataFrame: ...

    def __getitem__(
        self,
        key: (
            str
            | Collection[str]
            | AgentPolarsMask
            | tuple[AgentPolarsMask, str]
            | tuple[
                AgentPolarsMask,
                Collection[str],
            ]
        ),
    ) -> pl.Series | pl.DataFrame:
        """Delegate item access to the parent class, narrowing the result type."""
        attr = super().__getitem__(key)
        assert isinstance(attr, (pl.Series, pl.DataFrame))
        return attr

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Iterate over agents as ``{column: value}`` dictionaries."""
        return iter(self._agents.iter_rows(named=True))

    def __len__(self) -> int:
        """Return the number of agents (rows)."""
        return len(self._agents)

    def __reversed__(self) -> Iterator:
        """Iterate over agents as dictionaries, last row first."""
        # reversed() cannot be applied to the iterator from iter_rows(), so
        # the frame itself is reversed before iterating.
        return iter(self._agents.reverse().iter_rows(named=True))

    @property
    def agents(self) -> pl.DataFrame:
        """The full agent table."""
        return self._agents

    @agents.setter
    def agents(self, agents: pl.DataFrame) -> None:
        """Replace the agent table; a ``unique_id`` column is required."""
        if "unique_id" not in agents.columns:
            raise KeyError("DataFrame must have a unique_id column.")
        self._agents = agents

    @property
    def active_agents(self) -> pl.DataFrame:
        """Rows of the agent table selected by the current mask."""
        return self.agents.filter(self._mask)

    @active_agents.setter
    def active_agents(self, mask: AgentPolarsMask) -> None:
        """Set the active agents; equivalent to ``select(mask=mask)`` in place."""
        self.select(mask=mask, inplace=True)

    @property
    def inactive_agents(self) -> pl.DataFrame:
        """Rows of the agent table NOT selected by the current mask."""
        return self.agents.filter(~self._mask)

    @property
    def index(self) -> pl.Series:
        """The ``unique_id`` column of the agent table."""
        return self._agents["unique_id"]

    @property
    def pos(self) -> pl.DataFrame:
        """Agent positions, as provided by the parent class."""
        return super().pos