"""
Concrete class for data collection in mesa-frames.
This module defines a `DataCollector` implementation that gathers and optionally persists
model-level and agent-level data during simulations. It supports multiple storage backends,
including in-memory, CSV, Parquet, S3, and PostgreSQL, using Polars for efficient lazy
data processing.
Classes:
DataCollector:
A concrete class defining logic for all data collector implementations.
It supports flexible reporting of model and agent attributes, conditional
data collection using a trigger function, and pluggable backends for storage.
Supported Storage Backends:
- memory : In-memory collection (default)
- csv : Local CSV file output
- parquet : Local Parquet file output
- S3-csv : CSV files stored on Amazon S3
- S3-parquet : Parquet files stored on Amazon S3
- postgresql : PostgreSQL database with schema support
Triggers:
- A `trigger` parameter can be provided to control conditional collection.
This is a callable taking the model as input and returning a boolean.
If true, data is collected during `conditional_collect()`.
Usage:
The `DataCollector` class is designed to be used within a `ModelDF` instance
to collect model-level and/or agent-level data.
Example:
--------
from mesa_frames.concrete.model import ModelDF
from mesa_frames.concrete.datacollector import DataCollector
class ExampleModel(ModelDF):
def __init__(self, agents: AgentsDF):
super().__init__()
self.agents = agents
self.dc = DataCollector(
model=self,
# other required arguments
)
def step(self):
# Option 1: collect immediately
self.dc.collect()
# Option 2: collect based on condition
self.dc.conditional_collect()
# Write the collected data to the destination
self.dc.flush()
"""
import polars as pl
import boto3
from urllib.parse import urlparse
import tempfile
import psycopg2
from mesa_frames.abstract.datacollector import AbstractDataCollector
from typing import Any, Literal
from collections.abc import Callable
from mesa_frames import ModelDF
from psycopg2.extensions import connection
class DataCollector(AbstractDataCollector):
def __init__(
self,
model: ModelDF,
model_reporters: dict[str, Callable] | None = None,
agent_reporters: dict[str, str | Callable] | None = None,
trigger: Callable[[Any], bool] | None = None,
reset_memory: bool = True,
storage: Literal[
"memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"
] = "memory",
storage_uri: str | None = None,
schema: str = "public",
max_worker: int = 4,
):
"""
Initialize the DataCollector with configuration options.
Parameters
----------
model : ModelDF
The model object from which data is collected.
model_reporters : dict[str, Callable] | None
Functions to collect data at the model level.
agent_reporters : dict[str, str | Callable] | None
Attributes or functions to collect data at the agent level.
trigger : Callable[[Any], bool] | None
A function(model) -> bool that determines whether to collect data.
reset_memory : bool
Whether to reset in-memory data after flushing. Default is True.
storage : Literal["memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"]
Storage backend used to persist collected data. Default is "memory".
storage_uri : str | None
URI or path corresponding to the selected storage backend.
schema : str
Schema name used for PostgreSQL storage. Default is "public".
max_worker : int
Maximum number of worker threads used for flushing collected data asynchronously. Default is 4.
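Examples
--------
A minimal configuration sketch; the reporter, trigger, and output path shown
here are illustrative placeholders, not required values:

    dc = DataCollector(
        model=model,
        model_reporters={"n_agents": lambda m: len(m.agents)},
        trigger=lambda m: m.steps % 10 == 0,
        storage="csv",
        storage_uri="./results",
    )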
"""
super().__init__(
model=model,
model_reporters=model_reporters,
agent_reporters=agent_reporters,
trigger=trigger,
reset_memory=reset_memory,
storage=storage,
max_workers=max_worker,
)
self._writers = {
"csv": self._write_csv_local,
"parquet": self._write_parquet_local,
"S3-csv": self._write_csv_s3,
"S3-parquet": self._write_parquet_s3,
"postgresql": self._write_postgres,
}
self._storage_uri = storage_uri
self._schema = schema
self._current_model_step = None
self._batch_id = None
self._validate_inputs()
def _collect(self):
"""
Collect data from the model and agents for the current step.
This method checks for the presence of model and agent reporters
and calls the appropriate collection routines for each.
"""
if (
self._current_model_step is None
or self._current_model_step != self._model.steps
):
self._current_model_step = self._model.steps
self._batch_id = 0
if self._model_reporters:
self._collect_model_reporters(
current_model_step=self._current_model_step, batch_id=self._batch_id
)
if self._agent_reporters:
self._collect_agent_reporters(
current_model_step=self._current_model_step, batch_id=self._batch_id
)
self._batch_id += 1
def _collect_model_reporters(self, current_model_step: int, batch_id: int):
"""
Collect model-level data using the model_reporters.
Creates a single-row LazyFrame containing the step, seed, batch id, and the
value returned by each model reporter, then appends it to internal storage.
"""
model_data_dict = {}
model_data_dict["step"] = current_model_step
model_data_dict["seed"] = str(self.seed)
model_data_dict["batch"] = batch_id
for column_name, reporter in self._model_reporters.items():
model_data_dict[column_name] = reporter(self._model)
model_lazy_frame = pl.LazyFrame([model_data_dict])
self._frames.append(("model", current_model_step, batch_id, model_lazy_frame))
def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
"""
Collect agent-level data using the agent_reporters.
Constructs a LazyFrame with one column per reported value and adds `step`,
`seed`, and `batch` metadata columns, then appends it to internal storage.
"""
agent_data_dict = {}
for col_name, reporter in self._agent_reporters.items():
if isinstance(reporter, str):
for k, v in self._model.agents[reporter].items():
agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v
else:
agent_data_dict[col_name] = reporter(self._model)
agent_lazy_frame = pl.LazyFrame(agent_data_dict)
agent_lazy_frame = agent_lazy_frame.with_columns(
[
pl.lit(current_model_step).alias("step"),
pl.lit(str(self.seed)).alias("seed"),
pl.lit(batch_id).alias("batch"),
]
)
self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))
@property
def data(self) -> dict[str, pl.DataFrame]:
"""
Retrieve the collected data as eagerly evaluated Polars DataFrames.
Returns
-------
dict[str, pl.DataFrame]
A dictionary with keys "model" and "agent" mapping to concatenated DataFrames of collected data.
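Examples
--------
Assuming `model.dc` is a collector that has already collected data:

    results = model.dc.data
    model_df = results["model"].filter(pl.col("step") == 1)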
"""
model_frames = [
lf.collect() for kind, step, batch_id, lf in self._frames if kind == "model"
]
agent_frames = [
lf.collect() for kind, step, batch_id, lf in self._frames if kind == "agent"
]
return {
"model": pl.concat(model_frames) if model_frames else pl.DataFrame(),
"agent": pl.concat(agent_frames) if agent_frames else pl.DataFrame(),
}
def _flush(self, frames_to_flush: list):
"""
Flush the collected data to the configured external storage backend.
Uses the appropriate writer function based on the specified storage option.
"""
self._writers[self._storage](
uri=self._storage_uri, frames_to_flush=frames_to_flush
)
def _write_csv_local(self, uri: str, frames_to_flush: list):
"""
Write collected data to local CSV files.
Parameters
----------
uri : str
Local directory path to write files into.
frames_to_flush : list
Collected (kind, step, batch, LazyFrame) tuples to write.
"""
for kind, step, batch, df in frames_to_flush:
df.collect().write_csv(f"{uri}/{kind}_step{step}_batch{batch}.csv")
def _write_parquet_local(self, uri: str, frames_to_flush: list):
"""
Write collected data to local Parquet files.
Parameters
----------
uri : str
Local directory path to write files into.
frames_to_flush : list
Collected (kind, step, batch, LazyFrame) tuples to write.
"""
for kind, step, batch, df in frames_to_flush:
df.collect().write_parquet(f"{uri}/{kind}_step{step}_batch{batch}.parquet")
def _write_csv_s3(self, uri: str, frames_to_flush: list):
"""
Write collected data to AWS S3 in CSV format.
Parameters
----------
uri : str
S3 URI (e.g., s3://bucket/path) to upload files to.
frames_to_flush : list
Collected (kind, step, batch, LazyFrame) tuples to write.
"""
self._write_s3(uri=uri, frames_to_flush=frames_to_flush, format_="csv")
def _write_parquet_s3(self, uri: str, frames_to_flush: list):
"""
Write collected data to AWS S3 in Parquet format.
Parameters
----------
uri : str
S3 URI (e.g., s3://bucket/path) to upload files to.
frames_to_flush : list
Collected (kind, step, batch, LazyFrame) tuples to write.
"""
self._write_s3(uri=uri, frames_to_flush=frames_to_flush, format_="parquet")
def _write_s3(self, uri: str, frames_to_flush: list, format_: str):
"""
Upload collected data to S3 in a specified format.
Parameters
----------
uri : str
S3 URI to upload to.
frames_to_flush : list
Collected (kind, step, batch, LazyFrame) tuples to write.
format_ : str
Format of the output files ("csv" or "parquet").
"""
s3 = boto3.client("s3")
parsed = urlparse(uri)
bucket = parsed.netloc
prefix = parsed.path.lstrip("/")
for kind, step, batch, lf in frames_to_flush:
df = lf.collect()
with tempfile.NamedTemporaryFile(suffix=f".{format_}") as tmp:
if format_ == "csv":
df.write_csv(tmp.name)
elif format_ == "parquet":
df.write_parquet(tmp.name)
key = f"{prefix}/{kind}_step{step}_batch{batch}.{format_}"
s3.upload_file(tmp.name, bucket, key)
def _write_postgres(self, uri: str, frames_to_flush: list):
"""
Write collected data to a PostgreSQL database.
Each frame is inserted into the appropriate table (`model_data` or `agent_data`)
using batched insert queries.
Parameters
----------
uri : str
PostgreSQL connection URI, e.g. postgresql://user:password@localhost:5432/dbname
frames_to_flush : list
Collected (kind, step, batch, LazyFrame) tuples to write.
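Examples
--------
A sketch of compatible table DDL, assuming a single model reporter named
"n_agents" (hypothetical; create one column per reporter plus the step,
seed, and batch metadata columns):

    conn = psycopg2.connect("postgresql://user:pass@localhost:5432/dbname")
    with conn, conn.cursor() as cur:
        cur.execute(
            "CREATE TABLE public.model_data "
            "(step INTEGER, seed VARCHAR, batch INTEGER, n_agents INTEGER)"
        )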
"""
conn = self._get_db_connection(uri=uri)
cur = conn.cursor()
for kind, step, batch, lf in frames_to_flush:
df = lf.collect()
table = f"{kind}_data"
cols = df.columns
values = [tuple(row) for row in df.rows()]
placeholders = ", ".join(["%s"] * len(cols))
columns = ", ".join(cols)
cur.executemany(
f"INSERT INTO {self._schema}.{table} ({columns}) VALUES ({placeholders})",
values,
)
conn.commit()
cur.close()
conn.close()
def _get_db_connection(self, uri: str) -> connection:
"""
Create a psycopg2 connection from a PostgreSQL URI of the form postgresql://user:pass@host:port/dbname.
Parameters
----------
uri: str
PostgreSQL connection URI in the form postgresql://testuser:testpass@localhost:5432/testdb
Returns
-------
connection
psycopg2 connection
"""
parsed = urlparse(uri)
conn = psycopg2.connect(
dbname=parsed.path[1:], # remove leading slash
user=parsed.username,
password=parsed.password,
host=parsed.hostname,
port=parsed.port,
)
return conn
def _validate_inputs(self):
"""
Validate configuration and required schema for non-memory storage backends.
- Ensures a `storage_uri` is provided if needed.
- For PostgreSQL, validates that required tables and columns exist.
"""
if self._storage != "memory" and self._storage_uri == None:
raise ValueError(
"Please define a storage_uri to if to be stored not in memory"
)
if self._storage == "postgresql":
conn = self._get_db_connection(self._storage_uri)
try:
self._validate_postgress_table_exists(conn)
self._validate_postgress_columns_exists(conn)
finally:
conn.close()
def _validate_postgress_table_exists(self, conn: connection):
"""
Validate that the required PostgreSQL tables exist for storing model and agent data.
Parameters
----------
conn: connection
Open database connection.
"""
if self._model_reporters:
self._validate_reporter_table(conn=conn, table_name="model_data")
if self._agent_reporters:
self._validate_reporter_table(conn=conn, table_name="agent_data")
def _validate_postgress_columns_exists(self, conn: connection):
"""
Validate that required columns are present in the PostgreSQL tables.
Parameters
----------
conn: connection
Open database connection.
"""
if self._model_reporters:
self._validate_reporter_table_columns(
conn=conn, table_name="model_data", reporter=self._model_reporters
)
if self._agent_reporters:
self._validate_reporter_table_columns(
conn=conn, table_name="agent_data", reporter=self._agent_reporters
)
def _validate_reporter_table(self, conn: connection, table_name: str):
"""
Check if a given table exists in the PostgreSQL schema.
Parameters
----------
conn : connection
Open database connection.
table_name : str
Name of the table to check.
Raises
------
ValueError
If the table does not exist in the schema.
"""
query = f"""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = '{self._schema}' AND table_name = '{table_name}'
);"""
result = self._execute_query_with_result(conn, query)
if result == [(False,)]:
raise ValueError(
f"{self._schema}.{table_name} does not exist. To store collected data in DB please create a table with required columns"
)
def _validate_reporter_table_columns(
self, conn: connection, table_name: str, reporter: dict[str, Callable | str]
):
"""
Check if the expected columns are present in a given PostgreSQL table.
Parameters
----------
conn : connection
Open database connection.
table_name : str
Name of the table to validate.
reporter : dict[str, Callable | str]
Dictionary of reporters whose keys are expected as columns.
Raises
------
ValueError
If any expected columns are missing from the table.
"""
expected_columns = set()
for col_name, required_column in reporter.items():
if isinstance(required_column, str):
for k, v in self._model.agents[required_column].items():
expected_columns.add(
(col_name + "_" + str(k.__class__.__name__)).lower()
)
else:
expected_columns.add(col_name.lower())
query = f"""
SELECT column_name
FROM information_schema.columns
WHERE table_schema = '{self._schema}' AND table_name = '{table_name}';
"""
result = self._execute_query_with_result(conn, query)
if not result:
raise ValueError(
f"Could not retrieve columns for table {self._schema}.{table_name}"
)
existing_columns = {row[0] for row in result}
missing_columns = expected_columns - existing_columns
required_columns = {
"step": "Integer",
"seed": "Varchar",
}
missing_required = {
col: col_type
for col, col_type in required_columns.items()
if col not in existing_columns
}
if missing_columns or missing_required:
error_parts = []
if missing_columns:
error_parts.append(f"Missing columns: {sorted(missing_columns)}")
if missing_required:
required_list = [
f"`{col}` column of type ({col_type})"
for col, col_type in missing_required.items()
]
error_parts.append(
"Missing specific columns: " + ", ".join(required_list)
)
raise ValueError(
f"Missing columns in table {self._schema}.{table_name}: "
+ "; ".join(error_parts)
)
def _execute_query_with_result(self, conn: connection, query: str) -> list[tuple]:
"""
Execute a SQL query and return the fetched results.
Parameters
----------
conn : connection
Open database connection.
query : str
SQL query string.
Returns
-------
list[tuple]
Query result rows.
"""
with conn.cursor() as cur:
cur.execute(query)
return cur.fetchall()