# Source code for mesa_frames.concrete.agentset

"""
Polars-based implementation of AgentSet for mesa-frames.

This module provides a concrete implementation of the AgentSet class using Polars
as the backend for DataFrame operations. It defines the AgentSetPolars class,
which combines the abstract AgentSetDF functionality with Polars-specific
operations for efficient agent management and manipulation.

Classes:
    AgentSetPolars(AgentSetDF, PolarsMixin):
        A Polars-based implementation of the AgentSet. This class uses Polars
        DataFrames to store and manipulate agent data, providing high-performance
        operations for large numbers of agents.

The AgentSetPolars class is designed to be used within ModelDF instances or as
part of an AgentsDF collection. It leverages the power of Polars for fast and
efficient data operations on agent attributes and behaviors.

Usage:
    The AgentSetPolars class can be used directly in a model or as part of an
    AgentsDF collection:

    from mesa_frames.concrete.model import ModelDF
    from mesa_frames.concrete.agentset import AgentSetPolars
    import polars as pl

    class MyAgents(AgentSetPolars):
        def __init__(self, model):
            super().__init__(model)
            # Initialize with some agents
            self.add(pl.DataFrame({'id': range(100), 'wealth': 10}))

        def step(self):
            # Implement step behavior using Polars operations
            self.df = self.df.with_columns(wealth=pl.col('wealth') + 1)

    class MyModel(ModelDF):
        def __init__(self):
            super().__init__()
            self.agents += MyAgents(self)

        def step(self):
            self.agents.step()

Features:
    - Efficient storage and manipulation of large agent populations
    - Fast vectorized operations on agent attributes
    - Support for lazy evaluation and query optimization
    - Seamless integration with other mesa-frames components

Note:
    This implementation relies on Polars, so users should ensure that Polars
    is installed and imported. The performance characteristics of this class
    will depend on the Polars version and the specific operations used.

For more detailed information on the AgentSetPolars class and its methods,
refer to the class docstring.
"""

from __future__ import annotations

from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from typing import Any, Literal, Self, overload

import numpy as np
import polars as pl

from mesa_frames.concrete.agents import AgentSetDF
from mesa_frames.concrete.mixin import PolarsMixin
from mesa_frames.concrete.model import ModelDF
from mesa_frames.types_ import AgentPolarsMask, IntoExpr, PolarsIdsLike
from mesa_frames.utils import copydoc


@copydoc(AgentSetDF)
class AgentSetPolars(AgentSetDF, PolarsMixin):
    """Polars-based implementation of AgentSetDF."""

    # Agent storage: one row per agent, identified by the "unique_id" column.
    _df: pl.DataFrame
    # Active-agent selection: either an eager boolean Series lined up with the
    # rows of `_df`, or a lazy boolean expression evaluated against `_df`.
    _mask: pl.Expr | pl.Series

    # Attribute name -> (method name, arguments) used to duplicate that
    # attribute when the set is copied (presumably consumed by the copy logic
    # in the base classes); "_df" is duplicated via `clone()` with no arguments.
    _copy_with_method: dict[str, tuple[str, list[str]]] = {
        "_df": ("clone", []),
    }
    # Attributes that a copy shares by reference rather than duplicating.
    _copy_only_reference: list[str] = ["_model", "_mask"]
[docs] def __init__(self, model: mesa_frames.concrete.model.ModelDF) -> None: """Initialize a new AgentSetPolars. Parameters ---------- model : "mesa_frames.concrete.model.ModelDF" The model that the agent set belongs to. """ self._model = model # No definition of schema with unique_id, as it becomes hard to add new agents self._df = pl.DataFrame() self._mask = pl.repeat(True, len(self._df), dtype=pl.Boolean, eager=True)
[docs] def add( self, agents: pl.DataFrame | Sequence[Any] | dict[str, Any], inplace: bool = True, ) -> Self: """Add agents to the AgentSetPolars. Parameters ---------- agents : pl.DataFrame | Sequence[Any] | dict[str, Any] The agents to add. inplace : bool, optional Whether to add the agents in place, by default True. Returns ------- Self The updated AgentSetPolars. """ obj = self._get_obj(inplace) if isinstance(agents, AgentSetDF): raise TypeError( "AgentSetPolars.add() does not accept AgentSetDF objects. " "Extract the DataFrame with agents.agents.drop('unique_id') first." ) elif isinstance(agents, pl.DataFrame): if "unique_id" in agents.columns: raise ValueError("Dataframe should not have a unique_id column.") new_agents = agents elif isinstance(agents, dict): if "unique_id" in agents: raise ValueError("Dictionary should not have a unique_id key.") new_agents = pl.DataFrame(agents) else: # Sequence if len(obj._df) != 0: # For non-empty AgentSet, check column count expected_columns = len(obj._df.columns) - 1 # Exclude unique_id if len(agents) != expected_columns: raise ValueError( f"Length of data ({len(agents)}) must match the number of columns in the AgentSet (excluding unique_id): {expected_columns}" ) new_agents = pl.DataFrame( [list(agents)], schema=[col for col in obj._df.schema if col != "unique_id"], orient="row", ) else: # For empty AgentSet, cannot infer schema from sequence raise ValueError( "Cannot add a sequence to an empty AgentSet. Use a DataFrame or dict with column names." ) new_agents = new_agents.with_columns( self._generate_unique_ids(len(new_agents)).alias("unique_id") ) # If self._mask is pl.Expr, then new mask is the same. # If self._mask is pl.Series[bool], then new mask has to be updated. 
originally_empty = len(obj._df) == 0 if isinstance(obj._mask, pl.Series) and not originally_empty: original_active_indices = obj._df.filter(obj._mask)["unique_id"] obj._df = pl.concat([obj._df, new_agents], how="diagonal_relaxed") if isinstance(obj._mask, pl.Series) and not originally_empty: obj._update_mask(original_active_indices, new_agents["unique_id"]) return obj
@overload def contains(self, agents: int) -> bool: ... @overload def contains(self, agents: PolarsIdsLike) -> pl.Series: ...
[docs] def contains( self, agents: PolarsIdsLike, ) -> bool | pl.Series: if isinstance(agents, pl.Series): return agents.is_in(self._df["unique_id"]) elif isinstance(agents, Collection) and not isinstance(agents, str): return pl.Series(agents, dtype=pl.UInt64).is_in(self._df["unique_id"]) else: return agents in self._df["unique_id"]
[docs] def get( self, attr_names: IntoExpr | Iterable[IntoExpr] | None, mask: AgentPolarsMask = None, ) -> pl.Series | pl.DataFrame: masked_df = self._get_masked_df(mask) if attr_names is None: # Return all columns except unique_id return masked_df.select(pl.exclude("unique_id")) attr_names = self.df.select(attr_names).columns.copy() if not attr_names: return masked_df masked_df = masked_df.select(attr_names) if masked_df.shape[1] == 1: return masked_df[masked_df.columns[0]] return masked_df
[docs] def set( self, attr_names: str | Collection[str] | dict[str, Any] | None = None, values: Any | None = None, mask: AgentPolarsMask = None, inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) masked_df = obj._get_masked_df(mask) if not attr_names: attr_names = masked_df.columns attr_names.remove("unique_id") def process_single_attr( masked_df: pl.DataFrame, attr_name: str, values: Any ) -> pl.DataFrame: if isinstance(values, pl.DataFrame): values_series = values.to_series() elif isinstance(values, (pl.Expr, pl.Series, Collection)): values_series = pl.Series(values) else: values_series = pl.repeat(values, len(masked_df)) return masked_df.with_columns(values_series.alias(attr_name)) if isinstance(attr_names, str) and values is not None: masked_df = process_single_attr(masked_df, attr_names, values) elif isinstance(attr_names, Collection) and values is not None: if isinstance(values, Collection) and len(attr_names) == len(values): for attribute, val in zip(attr_names, values): masked_df = process_single_attr(masked_df, attribute, val) else: for attribute in attr_names: masked_df = process_single_attr(masked_df, attribute, values) elif isinstance(attr_names, dict): for key, val in attr_names.items(): masked_df = process_single_attr(masked_df, key, val) else: raise ValueError( "attr_names must be a string, a collection of string or a dictionary with columns as keys and values." ) unique_id_column = None if "unique_id" not in obj._df: unique_id_column = self._generate_unique_ids(len(masked_df)).alias( "unique_id" ) obj._df = obj._df.with_columns(unique_id_column) masked_df = masked_df.with_columns(unique_id_column) b_mask = obj._get_bool_mask(mask) non_masked_df = obj._df.filter(b_mask.not_()) original_index = obj._df.select("unique_id") obj._df = pl.concat([non_masked_df, masked_df], how="diagonal_relaxed") obj._df = original_index.join(obj._df, on="unique_id", how="left") obj._update_mask(original_index["unique_id"], unique_id_column) return obj
[docs] def select( self, mask: AgentPolarsMask = None, filter_func: Callable[[Self], pl.Series] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) mask = obj._get_bool_mask(mask) if filter_func: mask = mask & filter_func(obj) if n is not None: mask = (obj._df["unique_id"]).is_in( obj._df.filter(mask).sample(n)["unique_id"] ) if negate: mask = mask.not_() obj._mask = mask return obj
[docs] def shuffle(self, inplace: bool = True) -> Self: obj = self._get_obj(inplace) obj._df = obj._df.sample( fraction=1, shuffle=True, seed=obj.random.integers(np.iinfo(np.int32).max), ) return obj
[docs] def sort( self, by: str | Sequence[str], ascending: bool | Sequence[bool] = True, inplace: bool = True, **kwargs, ) -> Self: obj = self._get_obj(inplace) if isinstance(ascending, bool): descending = not ascending else: descending = [not a for a in ascending] obj._df = obj._df.sort(by=by, descending=descending, **kwargs) return obj
def _concatenate_agentsets( self, agentsets: Iterable[Self], duplicates_allowed: bool = True, keep_first_only: bool = True, original_masked_index: pl.Series | None = None, ) -> Self: if not duplicates_allowed: indices_list = [self._df["unique_id"]] + [ agentset._df["unique_id"] for agentset in agentsets ] all_indices = pl.concat(indices_list) if all_indices.is_duplicated().any(): raise ValueError( "Some ids are duplicated in the AgentSetDFs that are trying to be concatenated" ) if duplicates_allowed & keep_first_only: # Find the original_index list (ie longest index list), to sort correctly the rows after concatenation max_length = max(len(agentset) for agentset in agentsets) for agentset in agentsets: if len(agentset) == max_length: original_index = agentset._df["unique_id"] final_dfs = [self._df] final_active_indices = [self._df["unique_id"]] final_indices = self._df["unique_id"].clone() for obj in iter(agentsets): # Remove agents that are already in the final DataFrame final_dfs.append( obj._df.filter(pl.col("unique_id").is_in(final_indices).not_()) ) # Add the indices of the active agents of current AgentSet final_active_indices.append(obj._df.filter(obj._mask)["unique_id"]) # Update the indices of the agents in the final DataFrame final_indices = pl.concat( [final_indices, final_dfs[-1]["unique_id"]], how="vertical" ) # Left-join original index with concatenated dfs to keep original ids order final_df = original_index.to_frame().join( pl.concat(final_dfs, how="diagonal_relaxed"), on="unique_id", how="left" ) # final_active_index = pl.concat(final_active_indices, how="vertical") else: final_df = pl.concat([obj._df for obj in agentsets], how="diagonal_relaxed") final_active_index = pl.concat( [obj._df.filter(obj._mask)["unique_id"] for obj in agentsets] ) final_mask = final_df["unique_id"].is_in(final_active_index) self._df = final_df self._mask = final_mask # If some ids were removed in the do-method, we need to remove them also from final_df if not 
isinstance(original_masked_index, type(None)): ids_to_remove = original_masked_index.filter( original_masked_index.is_in(self._df["unique_id"]).not_() ) if not ids_to_remove.is_empty(): self.remove(ids_to_remove, inplace=True) return self def _get_bool_mask( self, mask: AgentPolarsMask = None, ) -> pl.Series | pl.Expr: def bool_mask_from_series(mask: pl.Series) -> pl.Series: if ( isinstance(mask, pl.Series) and mask.dtype == pl.Boolean and len(mask) == len(self._df) ): return mask return self._df["unique_id"].is_in(mask) if isinstance(mask, pl.Expr): return mask elif isinstance(mask, pl.Series): return bool_mask_from_series(mask) elif isinstance(mask, pl.DataFrame): if "unique_id" in mask.columns: return bool_mask_from_series(mask["unique_id"]) elif len(mask.columns) == 1 and mask.dtypes[0] == pl.Boolean: return bool_mask_from_series(mask[mask.columns[0]]) else: raise KeyError( "DataFrame must have a 'unique_id' column or a single boolean column." ) elif mask is None or mask == "all": return pl.repeat(True, len(self._df)) elif mask == "active": return self._mask elif isinstance(mask, Collection): return bool_mask_from_series(pl.Series(mask, dtype=pl.UInt64)) else: return bool_mask_from_series(pl.Series([mask], dtype=pl.UInt64)) def _get_masked_df( self, mask: AgentPolarsMask = None, ) -> pl.DataFrame: if (isinstance(mask, pl.Series) and mask.dtype == pl.Boolean) or isinstance( mask, pl.Expr ): return self._df.filter(mask) elif isinstance(mask, pl.DataFrame): if not mask["unique_id"].is_in(self._df["unique_id"]).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) return mask.select("unique_id").join(self._df, on="unique_id", how="left") elif isinstance(mask, pl.Series): if not mask.is_in(self._df["unique_id"]).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." 
) mask_df = mask.to_frame("unique_id") return mask_df.join(self._df, on="unique_id", how="left") elif mask is None or mask == "all": return self._df elif mask == "active": return self._df.filter(self._mask) else: if isinstance(mask, Collection): mask_series = pl.Series(mask, dtype=pl.UInt64) else: mask_series = pl.Series([mask], dtype=pl.UInt64) if not mask_series.is_in(self._df["unique_id"]).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) mask_df = mask_series.to_frame("unique_id") return mask_df.join(self._df, on="unique_id", how="left") @overload def _get_obj_copy(self, obj: pl.Series) -> pl.Series: ... @overload def _get_obj_copy(self, obj: pl.DataFrame) -> pl.DataFrame: ... def _get_obj_copy(self, obj: pl.Series | pl.DataFrame) -> pl.Series | pl.DataFrame: return obj.clone() def _discard(self, ids: PolarsIdsLike) -> Self: mask = self._get_bool_mask(ids) if isinstance(self._mask, pl.Series): original_active_indices = self._df.filter(self._mask)["unique_id"] self._df = self._df.filter(mask.not_()) if isinstance(self._mask, pl.Series): self._update_mask(original_active_indices) return self def _update_mask( self, original_active_indices: pl.Series, new_indices: pl.Series | None = None ) -> None: if new_indices is not None: self._mask = self._df["unique_id"].is_in( original_active_indices ) | self._df["unique_id"].is_in(new_indices) else: self._mask = self._df["unique_id"].is_in(original_active_indices)
[docs] def __getattr__(self, key: str) -> pl.Series: super().__getattr__(key) return self._df[key]
def _generate_unique_ids(self, n: int) -> pl.Series: return pl.Series( self.random.integers(1, np.iinfo(np.uint64).max, size=n, dtype=np.uint64) ) @overload def __getitem__( self, key: str | tuple[AgentPolarsMask, str], ) -> pl.Series: ... @overload def __getitem__( self, key: ( AgentPolarsMask | Collection[str] | tuple[ AgentPolarsMask, Collection[str], ] ), ) -> pl.DataFrame: ...
[docs] def __getitem__( self, key: ( str | Collection[str] | AgentPolarsMask | tuple[AgentPolarsMask, str] | tuple[ AgentPolarsMask, Collection[str], ] ), ) -> pl.Series | pl.DataFrame: attr = super().__getitem__(key) assert isinstance(attr, (pl.Series, pl.DataFrame)) return attr
[docs] def __iter__(self) -> Iterator[dict[str, Any]]: return iter(self._df.iter_rows(named=True))
[docs] def __len__(self) -> int: return len(self._df)
[docs] def __reversed__(self) -> Iterator: return reversed(iter(self._df.iter_rows(named=True)))
@property def df(self) -> pl.DataFrame: return self._df @df.setter def df(self, agents: pl.DataFrame) -> None: if "unique_id" not in agents.columns: raise KeyError("DataFrame must have a unique_id column.") self._df = agents @property def active_agents(self) -> pl.DataFrame: return self.df.filter(self._mask) @active_agents.setter def active_agents(self, mask: AgentPolarsMask) -> None: self.select(mask=mask, inplace=True) @property def inactive_agents(self) -> pl.DataFrame: return self.df.filter(~self._mask) @property def index(self) -> pl.Series: return self._df["unique_id"] @property def pos(self) -> pl.DataFrame: return super().pos