Skip to content
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Attention: The newest changes should be on top -->
- ENH: Monte Carlo Formatting Options [#947](https://github.com/RocketPy-Team/RocketPy/pull/947)
- ENH: ENH: Auto-Detection of Pressure Conversion Factor [#966](https://github.com/RocketPy-Team/RocketPy/pull/966)
- ENH: Auto-Detection of Pressure Conversion Factor [#966](https://github.com/RocketPy-Team/RocketPy/pull/966)
- ENH: Discrete and Continuous Controllers [#946](https://github.com/RocketPy-Team/RocketPy/pull/946)
- ENH: MNT: introduce pressure unit conversion when using forecast/reanalysis/ensemble data [#955](https://github.com/RocketPy-Team/RocketPy/pull/955)
- ENH: Auto Populate Changelog [#919](https://github.com/RocketPy-Team/RocketPy/pull/919)
- ENH: Adaptive Monte Carlo via Convergence Criteria [#922](https://github.com/RocketPy-Team/RocketPy/pull/922)
Expand Down
12 changes: 7 additions & 5 deletions rocketpy/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
self,
interactive_objects,
controller_function,
sampling_rate,
sampling_rate=None,
initial_observed_variables=None,
name="Controller",
):
Expand Down Expand Up @@ -71,12 +71,14 @@ def __init__(
objects as needed. The function return statement can be used to save
relevant information in the `observed_variables` list.

.. note:: The function will be called according to the sampling rate
specified.
sampling_rate : float
.. note:: The function will be called according to the sampling
rate specified. If `sampling_rate` is None, the controller
function is called at every solver step of the simulation.
sampling_rate : float, optional
The sampling rate of the controller function in Hertz (Hz). This
means that the controller function will be called every
`1/sampling_rate` seconds.
`1/sampling_rate` seconds. If None, it is treated as a
continuous controller and called at every solver step.
initial_observed_variables : list, optional
A list of the initial values of the variables that the controller
function returns. This list is used to initialize the
Expand Down
18 changes: 17 additions & 1 deletion rocketpy/simulation/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,16 @@ def __simulate(self, verbose):
self.y_sol = phase.solver.y
if verbose:
print(f"Current Simulation Time: {self.t:3.4f} s", end="\r")

if self._continuous_controllers:
for controller in self._continuous_controllers:
controller(
self.t,
self.y_sol,
self._controller_state_history,
self.sensors,
self.env,
)
self._controller_state_history.append(list(self.y_sol))
Comment on lines +780 to +789
if self.__check_simulation_events(phase, phase_index, node_index):
break # Stop if simulation termination event occurred

Expand Down Expand Up @@ -1537,6 +1546,7 @@ def __init_solver_monitors(self):

self.t_initial = self.initial_solution[0]
self.solution.append(self.initial_solution)
self._controller_state_history = [self.initial_solution[1:]]
self.t = self.solution[-1][0]
self.y_sol = self.solution[-1][1:]

Expand Down Expand Up @@ -1576,6 +1586,9 @@ def __init_equations_of_motion(self):
def __init_controllers(self):
"""Initialize controllers and sensors"""
self._controllers = self.rocket._controllers[:]
self._continuous_controllers = [
c for c in self._controllers if c.sampling_rate is None
]
self.sensors = self.rocket.sensors.get_components()

# reset controllable object to initial state (only airbrakes for now)
Expand Down Expand Up @@ -4488,6 +4501,9 @@ def add_parachutes(self, parachutes, t_init, t_end):

def add_controllers(self, controllers, t_init, t_end):
for controller in controllers:
# Skip node creation for continuous controllers
if controller.sampling_rate is None:
continue
# Calculate start of sampling time nodes
controller_time_step = 1 / controller.sampling_rate
controller_node_list = [
Expand Down
28 changes: 27 additions & 1 deletion tests/unit/simulation/test_flight_time_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
TimeNode.
"""

# from rocketpy.rocket import Parachute, _Controller
from rocketpy.control.controller import _Controller


def test_time_nodes_init(flight_calisto):
Expand Down Expand Up @@ -49,6 +49,32 @@ def test_time_nodes_add_node(flight_calisto):
# TODO: implement this test


def test_time_nodes_add_controllers_skips_continuous_controllers(flight_calisto):
"""Ensure only discrete controllers create time nodes."""
# Arrange
discrete_controller = _Controller(
interactive_objects=[],
controller_function=lambda t, sr, sv, sh, ov, io: None,
sampling_rate=10,
name="Discrete",
)
continuous_controller = _Controller(
interactive_objects=[],
controller_function=lambda t, sr, sv, sh, ov, io: None,
sampling_rate=None,
name="Continuous",
)
time_nodes = flight_calisto.TimeNodes()

# Act
time_nodes.add_controllers([discrete_controller, continuous_controller], 0, 1)

# Assert
assert len(time_nodes) == 11
assert all(node._controllers == [discrete_controller] for node in time_nodes)
assert all(continuous_controller not in node._controllers for node in time_nodes)


def test_time_nodes_sort(flight_calisto):
time_nodes = flight_calisto.TimeNodes()
time_nodes.add_node(3.0, [], [], [])
Expand Down
Loading