Source code for mesa_frames.concrete.agents

"""
Concrete implementation of the agents collection for mesa-frames.

This module provides the concrete implementation of the agents collection class
for the mesa-frames library. It defines the AgentsDF class, which serves as a
container for all agent sets in a model, leveraging DataFrame-based storage for
improved performance.

Classes:
    AgentsDF(AgentContainer):
        A collection of AgentSetDFs. This class acts as a container for all
        agents in the model, organizing them into separate AgentSetDF instances
        based on their types.

The AgentsDF class is designed to be used within ModelDF instances to manage
all agents in the simulation. It provides methods for adding, removing, and
accessing agents and agent sets, while taking advantage of the performance
benefits of DataFrame-based agent storage.

Usage:
    The AgentsDF class is typically instantiated and used within a ModelDF subclass:

    from mesa_frames.concrete.model import ModelDF
    from mesa_frames.concrete.agents import AgentsDF
    from mesa_frames.concrete import AgentSetPolars

    class MyCustomModel(ModelDF):
        def __init__(self):
            super().__init__()
            # Adding agent sets to the collection
            self.agents += AgentSetPolars(self)
            self.agents += AnotherAgentSetPolars(self)

        def step(self):
            # Step all agent sets
            self.agents.do("step")

Note:
    This concrete implementation builds upon the abstract AgentContainer class
    defined in the mesa_frames.abstract package, providing a ready-to-use
    agents collection that integrates with the DataFrame-based agent storage system.

For more detailed information on the AgentsDF class and its methods, refer to
the class docstring.
"""

from collections import defaultdict
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from typing import TYPE_CHECKING, Literal, cast

import polars as pl
from typing_extensions import Any, Self, overload

from mesa_frames.abstract.agents import AgentContainer, AgentSetDF
from mesa_frames.types_ import (
    AgentMask,
    AgnosticAgentMask,
    BoolSeries,
    DataFrame,
    IdsLike,
    Index,
    Series,
)

if TYPE_CHECKING:
    from mesa_frames.concrete.model import ModelDF


class AgentsDF(AgentContainer):
    """A collection of AgentSetDFs. All agents of the model are stored here.

    The container keeps its agent sets in insertion order and tracks the
    ``unique_id`` of every agent registered through :meth:`add`, so that
    duplicated IDs across agent sets can be rejected.
    """

    # Agent sets held by the container, in insertion order.
    _agentsets: list[AgentSetDF]
    # IDs of the agents registered through ``add`` and not yet removed.
    _ids: pl.Series

    def __init__(self, model: "ModelDF") -> None:
        """Initialize a new AgentsDF.

        Parameters
        ----------
        model : ModelDF
            The model associated with the AgentsDF.
        """
        self._model = model
        self._agentsets = []
        self._ids = pl.Series(name="unique_id", dtype=pl.Int64)

    def add(
        self, agents: AgentSetDF | Iterable[AgentSetDF], inplace: bool = True
    ) -> Self:
        """Add an AgentSetDF to the AgentsDF.

        Parameters
        ----------
        agents : AgentSetDF | Iterable[AgentSetDF]
            The AgentSetDF to add.
        inplace : bool, optional
            Whether to add the AgentSetDF in place.

        Returns
        -------
        Self
            The updated AgentsDF.

        Raises
        ------
        ValueError
            If some agentsets are already present in the AgentsDF or if the IDs are not unique.
        """
        obj = self._get_obj(inplace)
        other_list = obj._return_agentsets_list(agents)
        if obj._check_agentsets_presence(other_list).any():
            raise ValueError("Some agentsets are already present in the AgentsDF.")
        # Validate the combined IDs before touching the container state.
        new_ids = pl.concat(
            [obj._ids] + [pl.Series(agentset["unique_id"]) for agentset in other_list]
        )
        if new_ids.is_duplicated().any():
            raise ValueError("Some of the agent IDs are not unique.")
        obj._agentsets.extend(other_list)
        obj._ids = new_ids
        return obj

    @overload
    def contains(self, agents: int | AgentSetDF) -> bool: ...

    @overload
    def contains(self, agents: IdsLike | Iterable[AgentSetDF]) -> pl.Series: ...

    def contains(
        self, agents: IdsLike | AgentSetDF | Iterable[AgentSetDF]
    ) -> bool | pl.Series:
        """Check whether agents or agent sets are present in the AgentsDF.

        Parameters
        ----------
        agents : IdsLike | AgentSetDF | Iterable[AgentSetDF]
            A single ID, a single AgentSetDF, or an iterable of either.

        Returns
        -------
        bool | pl.Series
            A bool for a single ID or a single AgentSetDF, ``True`` for an
            empty iterable, otherwise a boolean Series.
        """
        if isinstance(agents, int):
            return agents in self._ids
        if isinstance(agents, AgentSetDF):
            return self._check_agentsets_presence([agents]).any()
        if isinstance(agents, Iterable):
            # One-shot iterables (e.g. generators) have no len() and would
            # lose their first element when peeked at: materialize them.
            if not hasattr(agents, "__len__"):
                agents = list(agents)
            if len(agents) == 0:
                return True
            if isinstance(next(iter(agents)), AgentSetDF):
                agents = cast(Iterable[AgentSetDF], agents)
                # NOTE(review): the result is aligned with the agent sets of
                # the container (one entry per stored agent set), not with
                # the order/length of ``agents`` -- confirm this is intended.
                return self._check_agentsets_presence(list(agents))
            # IDsLike
            agents = cast(IdsLike, agents)
            return pl.Series(agents).is_in(self._ids)
        # NOTE(review): any other input type falls through and yields None.

    @overload
    def do(
        self,
        method_name: str,
        *args,
        mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
        return_results: Literal[False] = False,
        inplace: bool = True,
        **kwargs,
    ) -> Self: ...

    @overload
    def do(
        self,
        method_name: str,
        *args,
        mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
        return_results: Literal[True],
        inplace: bool = True,
        **kwargs,
    ) -> dict[AgentSetDF, Any]: ...

    def do(
        self,
        method_name: str,
        *args,
        mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
        return_results: bool = False,
        inplace: bool = True,
        **kwargs,
    ) -> Self | Any:
        """Invoke a method on every (masked) agent set.

        Parameters
        ----------
        method_name : str
            Name of the method forwarded to ``AgentSetDF.do``.
        *args
            Positional arguments forwarded to the method.
        mask : AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask], optional
            A mask shared by all agent sets, or a dictionary with one mask
            per agent set.
        return_results : bool, optional
            Whether to return the per-agent-set results instead of the
            container, by default False.
        inplace : bool, optional
            Whether to operate in place, by default True.
        **kwargs
            Keyword arguments forwarded to the method.

        Returns
        -------
        Self | Any
            A dictionary mapping each agent set to the value returned by its
            ``do`` call when ``return_results`` is True, otherwise the
            updated AgentsDF.
        """
        obj = self._get_obj(inplace)
        agentsets_masks = obj._get_bool_masks(mask)
        outcomes = {
            agentset: agentset.do(
                method_name,
                *args,
                mask=agentset_mask,
                return_results=return_results,
                inplace=inplace,
                **kwargs,
            )
            for agentset, agentset_mask in agentsets_masks.items()
        }
        if return_results:
            return outcomes
        # NOTE(review): with a dict mask that lists only some agent sets, the
        # unlisted agent sets are dropped from the container here -- confirm.
        obj._agentsets = list(outcomes.values())
        return obj

    def get(
        self,
        attr_names: str | Collection[str] | None = None,
        mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
    ) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]:
        """Retrieve attributes from every (masked) agent set.

        Parameters
        ----------
        attr_names : str | Collection[str] | None, optional
            The attribute name(s) forwarded to ``AgentSetDF.get``.
        mask : AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask], optional
            A mask shared by all agent sets, or one mask per agent set.

        Returns
        -------
        dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]
            The value returned by each agent set, keyed by agent set.
        """
        agentsets_masks = self._get_bool_masks(mask)
        return {
            agentset: agentset.get(attr_names, agentset_mask)
            for agentset, agentset_mask in agentsets_masks.items()
        }

    def remove(
        self, agents: AgentSetDF | Iterable[AgentSetDF] | IdsLike, inplace: bool = True
    ) -> Self:
        """Remove agent sets or individual agents from the AgentsDF.

        Parameters
        ----------
        agents : AgentSetDF | Iterable[AgentSetDF] | IdsLike
            The agent set(s) to drop, or the IDs of the agents to discard.
        inplace : bool, optional
            Whether to operate in place, by default True.

        Returns
        -------
        Self
            The updated AgentsDF.

        Raises
        ------
        ValueError
            If an AgentSetDF to remove is not stored in the AgentsDF.
        KeyError
            If some IDs are not present in any agent set.
        """
        obj = self._get_obj(inplace)
        if agents is None:
            return obj
        if isinstance(agents, AgentSetDF):
            # Checked before the emptiness test: an agent set holding zero
            # agents must still be removable.
            agents = [agents]
        elif isinstance(agents, Iterable) and not hasattr(agents, "__len__"):
            # One-shot iterables (e.g. generators) have no len() and would
            # lose their first element when peeked at: materialize them.
            agents = list(agents)
        if isinstance(agents, Iterable) and len(agents) == 0:
            return obj

        if isinstance(agents, Iterable) and isinstance(next(iter(agents)), AgentSetDF):
            # We have to get the index of the original AgentSetDF because the
            # copy made AgentSetDFs with different hash. Positions are
            # de-duplicated and popped from the end so that earlier pops do
            # not shift the remaining positions.
            positions = sorted(
                {self._agentsets.index(agentset) for agentset in agents},
                reverse=True,
            )
            removed_ids = pl.concat(
                [pl.Series(dtype=pl.Int64)]
                + [
                    pl.Series(obj._agentsets.pop(position).index)
                    for position in positions
                ]
            )
        else:  # IDsLike
            if isinstance(agents, int):
                agents = [agents]
            elif isinstance(agents, DataFrame):
                agents = agents["unique_id"]
            removed_ids = pl.Series(agents)
            deleted = 0
            for agentset in obj._agentsets:
                initial_len = len(agentset)
                agentset._discard(removed_ids)
                deleted += initial_len - len(agentset)
                if deleted == len(removed_ids):
                    # Every requested ID has been found; skip the other sets.
                    break
            if deleted < len(removed_ids):  # TODO: fix type hint
                raise KeyError(
                    "There exist some IDs which are not present in any agentset"
                )
        try:
            obj.space.remove_agents(removed_ids, inplace=True)
        except ValueError:
            # Deliberate best effort. NOTE(review): presumably raised when the
            # model has no space or the agents were never placed -- confirm.
            pass
        obj._ids = obj._ids.filter(obj._ids.is_in(removed_ids).not_())
        return obj

    def select(
        self,
        mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
        filter_func: Callable[[AgentSetDF], AgentMask] | None = None,
        n: int | None = None,
        inplace: bool = True,
        negate: bool = False,
    ) -> Self:
        """Select agents in every (masked) agent set.

        Parameters
        ----------
        mask : AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask], optional
            A mask shared by all agent sets, or one mask per agent set.
        filter_func : Callable[[AgentSetDF], AgentMask] | None, optional
            Forwarded to ``AgentSetDF.select``.
        n : int | None, optional
            Total number of agents to select. It is split evenly between the
            agent sets with an integer division, so the remainder is dropped.
        inplace : bool, optional
            Whether to operate in place, by default True.
        negate : bool, optional
            Forwarded to ``AgentSetDF.select``, by default False.

        Returns
        -------
        Self
            The updated AgentsDF.
        """
        obj = self._get_obj(inplace)
        agentsets_masks = obj._get_bool_masks(mask)
        # Guard against a container without agent sets (division by zero).
        if n is not None and agentsets_masks:
            n = n // len(agentsets_masks)
        # NOTE(review): with a dict mask that lists only some agent sets, the
        # unlisted agent sets are dropped from the container here -- confirm.
        obj._agentsets = [
            agentset.select(
                mask=agentset_mask,
                filter_func=filter_func,
                n=n,
                negate=negate,
                inplace=inplace,
            )
            for agentset, agentset_mask in agentsets_masks.items()
        ]
        return obj

    def set(
        self,
        attr_names: str | dict[AgentSetDF, Any] | Collection[str],
        values: Any | None = None,
        mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
        inplace: bool = True,
    ) -> Self:
        """Set attributes on every (masked) agent set.

        Parameters
        ----------
        attr_names : str | dict[AgentSetDF, Any] | Collection[str]
            The attribute name(s) to set, or a dictionary mapping each agent
            set to the ``attr_names`` argument forwarded to its ``set``.
        values : Any | None, optional
            The values to set; ignored when ``attr_names`` is a dictionary.
        mask : AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask], optional
            A mask shared by all agent sets, or one mask per agent set.
        inplace : bool, optional
            Whether to operate in place, by default True.

        Returns
        -------
        Self
            The updated AgentsDF.
        """
        obj = self._get_obj(inplace)
        agentsets_masks = obj._get_bool_masks(mask)
        if isinstance(attr_names, dict):
            for agentset, set_values in attr_names.items():
                if not inplace:
                    # We have to get the index of the original AgentSetDF
                    # because the copy made AgentSetDFs with different hash
                    position = self._agentsets.index(agentset)
                    agentset = obj._agentsets[position]
                agentset.set(
                    attr_names=set_values,
                    mask=agentsets_masks[agentset],
                    inplace=True,
                )
        else:
            # NOTE(review): with a dict mask that lists only some agent sets,
            # the unlisted agent sets are dropped here -- confirm.
            obj._agentsets = [
                agentset.set(
                    attr_names=attr_names,
                    values=values,
                    mask=agentset_mask,
                    inplace=True,
                )
                for agentset, agentset_mask in agentsets_masks.items()
            ]
        return obj

    def shuffle(self, inplace: bool = True) -> Self:
        """Shuffle every agent set of the AgentsDF.

        Parameters
        ----------
        inplace : bool, optional
            Whether to operate in place, by default True.

        Returns
        -------
        Self
            The updated AgentsDF.
        """
        obj = self._get_obj(inplace)
        obj._agentsets = [agentset.shuffle(inplace=True) for agentset in obj._agentsets]
        return obj

    def sort(
        self,
        by: str | Sequence[str],
        ascending: bool | Sequence[bool] = True,
        inplace: bool = True,
        **kwargs,
    ) -> Self:
        """Sort every agent set of the AgentsDF.

        Parameters
        ----------
        by : str | Sequence[str]
            The attribute(s) to sort by, forwarded to ``AgentSetDF.sort``.
        ascending : bool | Sequence[bool], optional
            Forwarded to ``AgentSetDF.sort``, by default True.
        inplace : bool, optional
            Whether to operate in place, by default True.
        **kwargs
            Extra keyword arguments forwarded to ``AgentSetDF.sort``.

        Returns
        -------
        Self
            The updated AgentsDF.
        """
        obj = self._get_obj(inplace)
        obj._agentsets = [
            agentset.sort(by=by, ascending=ascending, inplace=inplace, **kwargs)
            for agentset in obj._agentsets
        ]
        return obj

    def step(self, inplace: bool = True) -> Self:
        """Advance the state of the agents in the AgentsDF by one step.

        Parameters
        ----------
        inplace : bool, optional
            Whether to update the AgentsDF in place, by default True

        Returns
        -------
        Self
        """
        obj = self._get_obj(inplace)
        for agentset in obj._agentsets:
            agentset.step()
        return obj

    def _check_ids_presence(self, other: list[AgentSetDF]) -> pl.DataFrame:
        """Check if the IDs of the agents to be added are unique.

        Parameters
        ----------
        other : list[AgentSetDF]
            The AgentSetDFs to check.

        Returns
        -------
        pl.DataFrame
            A DataFrame with the unique IDs and a boolean column indicating if they are present.
        """
        presence_df = pl.DataFrame(
            data={"unique_id": self._ids, "present": True},
            schema={"unique_id": pl.Int64, "present": pl.Boolean},
        )
        for agentset in other:
            new_ids = pl.Series(agentset.index)
            # IDs are compared against everything accumulated so far, i.e. the
            # stored IDs plus the IDs of the agent sets already processed.
            new_rows = (
                new_ids.is_in(presence_df["unique_id"])
                .to_frame("present")
                .with_columns(unique_id=new_ids)
                .select(["unique_id", "present"])
            )
            presence_df = pl.concat([presence_df, new_rows])
        # Drop the leading rows that describe the IDs already stored.
        return presence_df.slice(self._ids.len())

    def _check_agentsets_presence(self, other: list[AgentSetDF]) -> pl.Series:
        """Check if the agent sets to be added are already present in the AgentsDF.

        Parameters
        ----------
        other : list[AgentSetDF]
            The AgentSetDFs to check.

        Returns
        -------
        pl.Series
            A boolean Series with one entry per agent set stored in the
            AgentsDF, True where that agent set also appears in ``other``.
        """
        other_set = set(other)
        return pl.Series(
            [agentset in other_set for agentset in self._agentsets], dtype=pl.Boolean
        )

    def _get_bool_masks(
        self,
        mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
    ) -> dict[AgentSetDF, BoolSeries]:
        """Resolve a mask into one boolean mask per agent set.

        Parameters
        ----------
        mask : AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask], optional
            A mask applied to every stored agent set, or a dictionary with one
            mask per agent set (only the listed agent sets are resolved).

        Returns
        -------
        dict[AgentSetDF, BoolSeries]
            The result of ``AgentSetDF._get_bool_mask`` keyed by agent set.
        """
        if isinstance(mask, dict):
            return {
                agentset: agentset._get_bool_mask(agentset_mask)
                for agentset, agentset_mask in mask.items()
            }
        return {
            agentset: agentset._get_bool_mask(mask) for agentset in self._agentsets
        }

    def _return_agentsets_list(
        self, agentsets: AgentSetDF | Iterable[AgentSetDF]
    ) -> list[AgentSetDF]:
        """Convert the agentsets to a list of AgentSetDF.

        Parameters
        ----------
        agentsets : AgentSetDF | Iterable[AgentSetDF]

        Returns
        -------
        list[AgentSetDF]
        """
        return [agentsets] if isinstance(agentsets, AgentSetDF) else list(agentsets)

    def __add__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self:
        """Add AgentSetDFs to a new AgentsDF through the + operator.

        Parameters
        ----------
        other : AgentSetDF | Iterable[AgentSetDF]
            The AgentSetDFs to add.

        Returns
        -------
        Self
            A new AgentsDF with the added AgentSetDFs.
        """
        return super().__add__(other)

    def __getattr__(self, name: str) -> dict[AgentSetDF, Any]:
        """Fetch a public attribute from every agent set.

        Parameters
        ----------
        name : str
            The attribute name.

        Returns
        -------
        dict[AgentSetDF, Any]
            The attribute value of each agent set, keyed by agent set.

        Raises
        ------
        AttributeError
            If ``name`` starts with an underscore.
        """
        if name.startswith("_"):  # Avoids infinite recursion of private attributes
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )
        return {agentset: getattr(agentset, name) for agentset in self._agentsets}

    @overload
    def __getitem__(
        self, key: str | tuple[dict[AgentSetDF, AgentMask], str]
    ) -> dict[str, Series]: ...

    @overload
    def __getitem__(
        self,
        key: (
            Collection[str]
            | AgnosticAgentMask
            | IdsLike
            | tuple[dict[AgentSetDF, AgentMask], Collection[str]]
        ),
    ) -> dict[str, DataFrame]: ...

    def __getitem__(
        self,
        key: (
            str
            | Collection[str]
            | AgnosticAgentMask
            | IdsLike
            | tuple[dict[AgentSetDF, AgentMask], str]
            | tuple[dict[AgentSetDF, AgentMask], Collection[str]]
        ),
    ) -> dict[str, Series] | dict[str, DataFrame]:
        """Retrieve attributes through the [] operator (delegates to the parent class)."""
        return super().__getitem__(key)

    def __iadd__(self, agents: AgentSetDF | Iterable[AgentSetDF]) -> Self:
        """Add AgentSetDFs to the AgentsDF through the += operator.

        Parameters
        ----------
        agents : AgentSetDF | Iterable[AgentSetDF]
            The AgentSetDFs to add.

        Returns
        -------
        Self
            The updated AgentsDF.
        """
        return super().__iadd__(agents)

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Iterate over the agents of every agent set, in agent-set order."""
        return (agent for agentset in self._agentsets for agent in iter(agentset))

    def __isub__(self, agents: AgentSetDF | Iterable[AgentSetDF] | IdsLike) -> Self:
        """Remove AgentSetDFs from the AgentsDF through the -= operator.

        Parameters
        ----------
        agents : AgentSetDF | Iterable[AgentSetDF] | IdsLike
            The AgentSetDFs to remove.

        Returns
        -------
        Self
            The updated AgentsDF.
        """
        return super().__isub__(agents)

    def __len__(self) -> int:
        """Return the total number of agents across all agent sets."""
        return sum(len(agentset._agents) for agentset in self._agentsets)

    def __repr__(self) -> str:
        """Return the ``repr`` of every agent set, one per line."""
        return "\n".join([repr(agentset) for agentset in self._agentsets])

    def __reversed__(self) -> Iterator:
        """Iterate over the agents in the reverse order of ``__iter__``.

        The agent sets are visited last-to-first and each one is reversed.
        """
        # NOTE(review): relies on each AgentSetDF supporting ``reversed()``
        # -- confirm against the AgentSetDF implementation.
        return (
            agent
            for agentset in reversed(self._agentsets)
            for agent in reversed(agentset)
        )

    def __setitem__(
        self,
        key: (
            str
            | Collection[str]
            | AgnosticAgentMask
            | IdsLike
            | tuple[dict[AgentSetDF, AgentMask], str]
            | tuple[dict[AgentSetDF, AgentMask], Collection[str]]
        ),
        values: Any,
    ) -> None:
        """Set attributes through the [] operator (delegates to the parent class)."""
        super().__setitem__(key, values)

    def __str__(self) -> str:
        """Return the ``str`` of every agent set, one per line."""
        return "\n".join([str(agentset) for agentset in self._agentsets])

    def __sub__(self, agents: IdsLike | AgentSetDF | Iterable[AgentSetDF]) -> Self:
        """Remove AgentSetDFs from a new AgentsDF through the - operator.

        Parameters
        ----------
        agents : IdsLike | AgentSetDF | Iterable[AgentSetDF]
            The AgentSetDFs to remove.

        Returns
        -------
        Self
            A new AgentsDF with the removed AgentSetDFs.
        """
        return super().__sub__(agents)

    @property
    def agents(self) -> dict[AgentSetDF, DataFrame]:
        """The ``agents`` of every agent set, keyed by agent set."""
        return {agentset: agentset.agents for agentset in self._agentsets}

    @agents.setter
    def agents(self, other: Iterable[AgentSetDF]) -> None:
        """Set the agents in the AgentsDF.

        Parameters
        ----------
        other : Iterable[AgentSetDF]
            The AgentSetDFs to set.
        """
        # NOTE(review): ``_ids`` is not refreshed here, so it may no longer
        # match the stored agent sets -- confirm whether that is intended.
        self._agentsets = list(other)

    @property
    def active_agents(self) -> dict[AgentSetDF, DataFrame]:
        """The ``active_agents`` of every agent set, keyed by agent set."""
        return {agentset: agentset.active_agents for agentset in self._agentsets}

    @active_agents.setter
    def active_agents(
        self, agents: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask]
    ) -> None:
        """Select the active agents in place, using ``agents`` as the mask."""
        self.select(agents, inplace=True)

    @property
    def agentsets_by_type(self) -> dict[type[AgentSetDF], Self]:
        """Get the agent sets in the AgentsDF grouped by type.

        Returns
        -------
        dict[type[AgentSetDF], Self]
            A dictionary mapping agent set types to the corresponding AgentsDF.
        """

        def empty_container() -> Self:
            # Start every group from a fresh list so that agent sets sharing
            # a type accumulate instead of overwriting one another.
            container = self.copy(deep=False, skip=["_agentsets"])
            container._agentsets = []
            return container

        grouped: dict[type[AgentSetDF], Self] = defaultdict(empty_container)
        for agentset in self._agentsets:
            grouped[agentset.__class__]._agentsets.append(agentset)
        return grouped

    @property
    def inactive_agents(self) -> dict[AgentSetDF, DataFrame]:
        """The ``inactive_agents`` of every agent set, keyed by agent set."""
        return {agentset: agentset.inactive_agents for agentset in self._agentsets}

    @property
    def index(self) -> dict[AgentSetDF, Index]:
        """The ``index`` of every agent set, keyed by agent set."""
        return {agentset: agentset.index for agentset in self._agentsets}

    @property
    def pos(self) -> dict[AgentSetDF, DataFrame]:
        """The ``pos`` of every agent set, keyed by agent set."""
        return {agentset: agentset.pos for agentset in self._agentsets}