import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from pydrake.all import (
    Context,
    OutputPort,
    PortDataType,
    MultibodyPlant,
)
from pydrake.systems.primitives import VectorLogSink

# Internal Imports
from brom_drake.PortWatcher.file_manager import PortWatcherFileManager
from brom_drake.PortWatcher.port_figure_arrangement import PortFigureArrangement
from brom_drake.PortWatcher.port_watcher_options import FigureNamingConvention, PortWatcherPlottingOptions
from brom_drake.utils.constants import SupportedLogger
from brom_drake.utils.type_checking import is_rigid_transform
from brom_drake.directories import DEFAULT_PLOT_DIR
# Class Definitions
class PortWatcherPlotter:
    """
    *Description*

    A plotter for the PortWatcher object. This is responsible for interpreting
    the data from the :py:class:`brom_drake.PortWatcher.PortWatcher` and
    drawing it as matplotlib figures.

    TODO(Kwesi): Consider making this a dataclass
    """

    def __init__(
        self,
        port: OutputPort,
        python_logger: logging.Logger,
        file_manager: PortWatcherFileManager,
        plotting_options: Optional[PortWatcherPlottingOptions] = None,
    ):
        """
        *Description*

        Stores the collaborators that the plotting methods rely on.

        *Parameters*

        port : OutputPort
            The output port whose logged data will be plotted.

        python_logger : logging.Logger
            Logger that receives progress messages and warnings.

        file_manager : PortWatcherFileManager
            Used here to look up the title of each plotted dimension.

        plotting_options : PortWatcherPlottingOptions, optional
            Options controlling how the figures are arranged. When omitted
            (or None), a fresh ``PortWatcherPlottingOptions()`` is created.
        """
        # Build the default per instance. A default constructed in the
        # signature is evaluated once and shared by every plotter.
        if plotting_options is None:
            plotting_options = PortWatcherPlottingOptions()

        # Setup
        self.port = port
        self.plotting_options = plotting_options
        self.file_manager = file_manager
        self.python_logger = python_logger

    def compute_plot_shape(self, n_dims: int) -> Tuple[int, int]:
        """
        *Description*

        Computes the shape of the subplot grid based on the data.

        One dimension gets a single plot, two dimensions sit side by side,
        up to eight dimensions use two rows and anything larger uses three.

        *Parameters*

        n_dims: int
            The number of dimensions in the data.

        *Returns*

        n_rows: int
            The number of rows in the plot.

        n_columns: int
            The number of columns in the plot
        """
        if n_dims == 1:
            return 1, 1
        if n_dims == 2:
            return 1, 2

        n_rows = 2 if n_dims < 9 else 3
        # Integer ceiling division: just enough columns to fit every dimension.
        n_cols = (n_dims + n_rows - 1) // n_rows
        return n_rows, n_cols

    def add_to_python_report(self, message: str):
        """
        *Description*

        Logs a message to the Python logger at INFO level.

        *Parameters*

        message : str
            A string with the message we want to send to the logs.
        """
        self.python_logger.info(message)

    def add_warning_to_python_report(self, message: str):
        """
        *Description*

        Logs a message to the Python logger at WARNING level.

        *Parameters*

        message : str
            A string with the warning we want to send to the logs.
        """
        self.python_logger.warning(message)

    def data_dimension(self) -> int:
        """
        *Description*

        Returns the dimension of the data in the port.

        *Parameters*

        self : PortWatcherPlotter
            The PortWatcherPlotter object.

        *Returns*

        data_dim: int
            The dimension of the data in the port: the port size for
            vector-valued ports, 7 for ports holding a RigidTransform and
            1 for ports holding a bool.

        *Raises*

        ValueError
            If the port is not vector-valued and its value is neither a
            RigidTransform nor a bool.
        """
        if self.port.get_data_type() == PortDataType.kVectorValued:
            return self.port.size()

        # Abstract-valued port: allocate an example value to inspect its type.
        example_value = self.port.Allocate().get_value()

        # If port contains RigidTransform, then the expected data dimension is 7.
        if is_rigid_transform(example_value):
            return 7
        if isinstance(example_value, bool):
            return 1

        # Otherwise, raise an error
        raise ValueError(
            f"Port {self.port.get_name()} of system {self.port.get_system().get_name()} is not of the correct type for plotting."
        )

    def plot_logger_data(
        self,
        drake_vector_log: VectorLogSink,
        diagram_context: Context,
    ) -> Tuple[Optional[List[plt.Figure]], Optional[List[List[plt.Axes]]]]:
        """
        *Description*

        This function plots the data in the logger, arranged according to
        ``self.plotting_options.plot_arrangement``.

        *Parameters*

        self : PortWatcherPlotter
            The PortWatcherPlotter object.

        drake_vector_log : VectorLogSink
            The log sink holding the recorded samples for this port.

        diagram_context : Context
            The context of the diagram.

        *Returns*

        Tuple[List[plt.Figure], List[List[plt.Axes]]]
            A tuple where:
            - the first element is a list of figures and
            - the second element is a list of lists of axes.

            With ``OnePlotPerPort`` the second element is a one-element list
            holding whatever :py:meth:`plot_logger_data_subplots` returned
            for its axes. If the log contains no data, a warning is logged
            and ``(None, None)`` is returned instead.

        *Raises*

        ValueError
            If the plot arrangement is not a supported value.
        """
        # Setup
        arrangement = self.plotting_options.plot_arrangement

        # Get the log from the drake_vector_log
        port_log = drake_vector_log.FindLog(diagram_context)
        log_times = port_log.sample_times()
        log_data = port_log.data()

        # Nothing to draw if either axis of the log is empty.
        if (log_data.shape[1] == 0) or (log_data.shape[0] == 0):
            system = self.port.get_system()
            self.add_warning_to_python_report(
                f"No data found for {system.get_name()} - Port {self.port.get_name()}! Skipping..."
            )
            return None, None

        # Plot the data
        if arrangement == PortFigureArrangement.OnePlotPerPort:
            fig, ax_list = self.plot_logger_data_subplots(log_times, log_data, drake_vector_log)
            return [fig], [ax_list]

        if arrangement == PortFigureArrangement.OnePlotPerDim:
            figs, ax_grid = [], []
            # Iterate over the rows that were actually logged (as the subplot
            # path does) rather than over port.size(), so that every logged
            # dimension gets a figure.
            for dim_index in range(log_data.shape[0]):
                fig_ii = plt.figure()
                ax_ii = fig_ii.add_subplot(1, 1, 1)
                ax_ii.plot(log_times, log_data[dim_index, :])

                # Add axis titles and labels
                ax_ii.set_xlabel("Time (s)")
                ax_ii.set_title(
                    self.file_manager.name_of_data_at_index(dim_index, self.port, drake_vector_log)
                )

                # Save figures and axes to lists
                figs.append(fig_ii)
                ax_grid.append([ax_ii])

            return figs, ax_grid

        raise ValueError(
            f"Invalid plot arrangement: {plotting_options_arrangement_repr(arrangement)}."
        ) if False else ValueError(
            f"Invalid plot arrangement: {arrangement}."
        )

    def plot_logger_data_subplots(
        self,
        times: np.ndarray,
        data: np.ndarray,
        drake_vector_log: VectorLogSink,
    ):
        """
        *Description*

        This function plots every dimension of the logged data as one
        subplot of a single figure.

        TODO(Kwesi): Consider refactoring this to remove the dependency on drake_vector_log

        *Parameters*

        self : PortWatcherPlotter
            The PortWatcherPlotter object.

        times : np.ndarray
            The times at which the data was recorded.

        data : np.ndarray
            The data that was recorded; row ``i`` is plotted as dimension ``i``.

        drake_vector_log : VectorLogSink
            Forwarded to the file manager when naming each dimension.

        *Returns*

        fig_out: plt.Figure
            The figure containing subplots that we use in the output.
            The figure is not closed by this method.

        axes_list: plt.Axes or np.ndarray of plt.Axes
            The axes exactly as returned by ``plt.subplots``: a single Axes
            for a 1x1 grid, a 1-D array for a single row and a 2-D array
            otherwise.
        """
        # Setup
        n_dims = data.shape[0]

        # Plot the data
        n_rows, n_cols = self.compute_plot_shape(n_dims)
        self.add_to_python_report(
            f"Plotting {n_dims} dimensions in a {n_rows}x{n_cols} grid."
        )

        fig, ax_list = plt.subplots(n_rows, n_cols)

        # plt.subplots squeezes its result, so normalise it to a flat,
        # row-major sequence and handle every grid shape with one loop.
        flat_axes = ax_list.flat if isinstance(ax_list, np.ndarray) else [ax_list]

        for dim_index, ax in enumerate(flat_axes):
            if dim_index >= n_dims:
                # Grid cell with no matching dimension: remove the empty axes.
                fig.delaxes(ax)
                continue

            ax.plot(times, data[dim_index, :])
            ax.set_title(
                self.file_manager.name_of_data_at_index(dim_index, self.port, drake_vector_log)
            )

        return fig, ax_list

    def system_is_multibody_plant(self) -> bool:
        """
        *Description*

        Returns True if the system is a MultibodyPlant.

        *Parameters*

        self : PortWatcherPlotter
            The PortWatcherPlotter object.

        *Returns*

        is_plant: bool
            True if the system that owns the port is a MultibodyPlant
            (or an instance of a subclass of it).
        """
        return isinstance(self.port.get_system(), MultibodyPlant)