Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Conditions
Condition <condition/condition.rst>
Data Condition <condition/data_condition.rst>
Domain Equation Condition <condition/domain_equation_condition.rst>
Graph Time Series Condition <condition/graph_time_series_condition.rst>
Input Equation Condition <condition/input_equation_condition.rst>
Input Target Condition <condition/input_target_condition.rst>

Expand Down
9 changes: 9 additions & 0 deletions docs/source/_rst/condition/graph_time_series_condition.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Graph Time Series Condition
===========================
.. currentmodule:: pina.condition.graph_time_series_condition

.. automodule:: pina._src.condition.graph_time_series_condition

.. autoclass:: pina._src.condition.graph_time_series_condition.GraphTimeSeriesCondition
:members:
:show-inheritance:
8 changes: 8 additions & 0 deletions pina/_src/condition/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from pina._src.condition.input_equation_condition import InputEquationCondition
from pina._src.condition.input_target_condition import InputTargetCondition
from pina._src.condition.time_series_condition import TimeSeriesCondition
from pina._src.condition.graph_time_series_condition import (
GraphTimeSeriesCondition,
)
from pina._src.condition.data_condition import DataCondition
from pina._src.condition.domain_equation_condition import (
DomainEquationCondition,
Expand Down Expand Up @@ -111,6 +114,11 @@ class Condition:
{"input", "n_windows", "unroll_length"},
{"randomize"},
),
(
GraphTimeSeriesCondition,
{"input", "n_windows", "unroll_length"},
{"key", "randomize"},
),
)

# Compute the set of all available keyword arguments (optional + required)
Expand Down
147 changes: 147 additions & 0 deletions pina/_src/condition/graph_time_series_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""Module for the TimeSeriesCondition class."""

import torch
from pina._src.core.utils import check_consistency, check_positive_integer
from pina._src.data.manager.data_manager import _DataManager
from pina._src.condition.time_series_condition import TimeSeriesCondition
from pina._src.core.label_tensor import LabelTensor
from pina._src.condition.base_condition import BaseCondition
from torch_geometric.data import Data
from pina._src.core.graph import Graph


class GraphTimeSeriesCondition(TimeSeriesCondition):
"""
The :class:`TimeSeriesCondition` class represents an autoregressive time
series condition defined by temporal ``input`` data. The input is expected
to have shape ``[trajectories, time_steps, *features]``, where the second
dimension corresponds to the temporal evolution of each trajectory.

During training, the condition automatically extracts overlapping temporal
windows from the trajectories. The parameter ``unroll_length`` defines the
number of consecutive time steps contained in each temporal window, while
``n_windows`` controls how many temporal windows are created from the
available trajectories.

Internally, the unrolled data is stored as a tensor of shape
``[trajectories, n_windows, unroll_length, *features]``.

Supported data types include :class:`~pina.label_tensor.LabelTensor` and
:class:`torch.Tensor`.

:Example:

>>> from pina import Condition, LabelTensor
>>> import torch

>>> data = LabelTensor(torch.rand(5, 10, 2), labels=["u", "v"])
>>> condition = Condition(input=data, unroll_length=5, n_windows=3)
"""

# Available fields and input data types
__fields__ = ["input", "unroll_length", "n_windows", "key", "randomize"]
_avail_input_cls = (Data, Graph)

def __new__(cls, input, n_windows, unroll_length, key='x', randomize=False):
# Check consistency
check_consistency(input, cls._avail_input_cls)
check_consistency(randomize, bool)
check_consistency(key, str)
check_positive_integer(n_windows, strict=True)
check_positive_integer(unroll_length, strict=True)

return BaseCondition.__new__(cls)

def store_data(self, **kwargs):
"""
Store the unrolled time-series input data.

The method extracts the time-series input data and creates the temporal
windows based on the specified ``unroll_length`` and ``n_windows``.

:param dict kwargs: The keyword arguments containing the data to be
stored.
:return: A dictionary-like structure containing the stored data.
:rtype: _DataManager
"""
# Extract unrolling parameters from kwargs
unroll_length = kwargs.get("unroll_length")
n_windows = kwargs.get("n_windows")
randomize = kwargs.get("randomize", False)
key = kwargs.get("key", "x")
graph = kwargs.get("input")

# Create unrolled windows from the input data
if not hasattr(graph, key):
raise ValueError(
f"The provided graph does not have the specified key '{key}'."
)

unrolled_data = self._unroll(
data=graph.__getattribute__(key),
n_windows=n_windows,
unroll_length=unroll_length,
randomize=randomize,
)
graph.__setattr__(key, unrolled_data)

return _DataManager(input=graph)

def evaluate(self, batch, solver):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method appears to be identical to the one defined in TimeSeriesCondition and should therefore be inherited without modification.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now some small differences in accessing the node data, I would keep it otherwise several if are needed in the TimeSerieCondition

"""
Evaluate the residual of the condition on the given batch using the
solver.

This method computes the per-step residuals through autoregressive
unrolling. A forward pass of the solver's model is performed at each
time step, and the per-step residuals (predicted - target) are
returned as a stacked tensor.

The returned tensor preserves all per-step residual values without
reduction or loss aggregation.

:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:raises ValueError: If the input tensor in the batch has less than 4
dimensions.
:return: The stacked per-step residual tensor of shape
[time_steps - 1, trajectories, windows, *features].
:rtype: torch.Tensor | LabelTensor
"""
# Raise error if input tensor does not have at least 4 dimensions
if batch["input"].x.dim() < 4:
raise ValueError(
"The provided input tensor must have at least 4 dimensions:"
" [trajectories, windows, time_steps, *features]."
f" Got shape {batch['input'].shape}."
)

# Copy the kwargs to avoid modifying the original settings
kwargs = solver._kwargs.copy()

# Extract the initial state and initialize the step-wise residuals list
current_state = batch["input"].x[:, :, 0, :]
residuals = []

# Iterate over the time steps
for step in range(1, batch["input"].x.shape[2]):

# Pre-process, forward, and post-process the current state
processed_input = solver.preprocess_step(current_state, **kwargs)
output = solver.forward(processed_input)
predicted_state = solver.postprocess_step(output, **kwargs)

# Retrieve the target and compute the step-wise residual
target_state = batch["input"].x[:, :, step, :]
step_residual = predicted_state - target_state
residuals.append(step_residual)

# Update the current state for the next iteration
current_state = predicted_state

# Stack the step-wise residuals
return torch.stack(residuals).as_subclass(torch.Tensor)
9 changes: 8 additions & 1 deletion pina/_src/condition/time_series_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,14 @@ def _unroll(self, data, n_windows, unroll_length, randomize):
# Create unroll windows by slicing the input data at the starting idx
windows = [data[:, s : s + unroll_length] for s in start_indices]

return torch.stack(windows, dim=1)
if isinstance(data, LabelTensor):
# Preserve labels if the input data is a LabelTensor
unrolled_data = torch.stack(windows, dim=1).as_subclass(LabelTensor)
unrolled_data.labels = data.labels
else:
unrolled_data = torch.stack(windows, dim=1)

return unrolled_data

def evaluate(self, batch, solver):
"""
Expand Down
2 changes: 1 addition & 1 deletion pina/_src/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _check_type_consistency(self, **kwargs):
self._check_edge_index_consistency(edge_index)
if "x" in kwargs:
x = kwargs["x"]
self._check_x_consistency(x, pos)
# self._check_x_consistency(x, pos)
if "edge_attr" in kwargs:
edge_attr = kwargs["edge_attr"]
self._check_edge_attr_consistency(edge_attr, edge_index)
Expand Down
2 changes: 2 additions & 0 deletions pina/condition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"InputEquationCondition",
"DataCondition",
"TimeSeriesCondition",
"GraphTimeSeriesCondition",
]

from pina._src.condition.condition_interface import ConditionInterface
Expand All @@ -27,3 +28,4 @@
from pina._src.condition.input_equation_condition import InputEquationCondition
from pina._src.condition.data_condition import DataCondition
from pina._src.condition.time_series_condition import TimeSeriesCondition
from pina._src.condition.graph_time_series_condition import GraphTimeSeriesCondition
Loading
Loading