Source code for brom_drake.PortWatcher.file_manager

from brom_drake.PortWatcher.port_watcher_options import (
    PortFigureArrangement,
    FigureNamingConvention,
    PortWatcherPlottingOptions, 
)
from dataclasses import dataclass
from pathlib import Path
from pydrake.multibody.plant import MultibodyPlant
from pydrake.systems.framework import OutputPort
from pydrake.systems.primitives import VectorLogSink
from typing import List

[docs] @dataclass class PortWatcherFileManager: """ *Description* This class manages file paths and directories for saving data collected by the PortWatcher system. """ base_directory: Path plotting_options: PortWatcherPlottingOptions raw_data_file_format: str = "npy"
[docs] @staticmethod def compute_safe_system_name(system_name: str) -> str: """ *Description* This function returns a filesystem-safe version of the system name. *Returns* safe_system_name: str The filesystem-safe version of the system name. """ # First, let's check to see how many "/" exist in the name slash_occurences = [i for i, letter in enumerate(system_name) if letter == "/"] if len(slash_occurences) > 0: system_name = system_name[slash_occurences[-1] + 1:] # truncrate string based on the last slash # Second, replace all spaces with underscores system_name = system_name.replace(" ", "_") return system_name
[docs] def compute_path_for_each_figure( self, output_port: OutputPort, associated_log_sink: VectorLogSink, port_component_name: str = None ) -> List[Path]: """ *Description* Computes the names of all of the figures that will be produced for *Returns* figure_paths_out: List[Path] The paths of all of the figures that will be produced by this PortWatcherPlotter object. """ # Setup plotting_options = self.plotting_options # Create the figure paths based on the naming convention given to the # PortWatcherPlotter. match plotting_options.figure_naming_convention: case FigureNamingConvention.kFlat: return self.figure_names_under_flat_convention(output_port, associated_log_sink, port_component_name) case FigureNamingConvention.kHierarchical: return self.figure_names_under_hierarchical_convention(output_port, associated_log_sink, port_component_name) case _: raise NotImplementedError( f"Invalid figure naming convention for figure_names(): {plotting_options.figure_naming_convention}." )
[docs] def figure_names_under_flat_convention( self, output_port: OutputPort, associated_log_sink: VectorLogSink, port_component_name: str = None ) -> List[Path]: """ *Description* Returns the names associated with each figure that this port will generate assuming we are under the kFlat convention. *Arguments* output_port: OutputPort The output port for which we are generating figure names. We can extract the system name, port name, and port size from this. port_component_name: str A "sub-component" of the port that we wish to give a unique name in the figures. *Returns* figure_names: List[Path] List of paths where each path is a file name for an associated figure. """ # Setup plotting_options = self.plotting_options format = plotting_options.file_format plot_dir = self.plot_dir # The naming also depends on the arrangement of the plots # (i.e., if there is one plot per port, or one plot per dimension) system = output_port.get_system() system_name = system.get_name() safe_system_name = self.compute_safe_system_name(system_name) port_name = output_port.get_name() log_sink_size = associated_log_sink.get_input_port().size() match plotting_options.plot_arrangement: case PortFigureArrangement.OnePlotPerPort: if port_component_name is None: # If there is no sub-component name, then we just # create the file in the main plot directory return [ plot_dir / f"system_{safe_system_name}_port_{port_name}.{format}" ] else: # If there is a sub-component name, then we will # create a sub-directory for it return [ plot_dir / f"system_{safe_system_name}_port_{port_name}" / f"{port_component_name}.{format}" ] case PortFigureArrangement.OnePlotPerDim: if port_component_name is None: # If there is no sub-component name, then we just # create the files in the main plot directory return [ plot_dir / f"system_{safe_system_name}_port_{port_name}_dim{ii}.{format}" for ii in range(log_sink_size) ] else: # If there is a sub-component name, then we will # create a sub-directory for it return [ plot_dir / f"system_{safe_system_name}_port_{port_name}" / f"{port_component_name}_dim{ii}.{format}" for ii in range(log_sink_size) ] case _: raise NotImplementedError( f"Invalid plot arrangement for figure naming convention {plotting_options.figure_naming_convention}: {plotting_options.plot_arrangement}." )
[docs] def figure_names_under_hierarchical_convention( self, output_port: OutputPort, associated_log_sink: VectorLogSink, port_component_name: str = None, ) -> List[Path]: """ *Description* Returns the names associated with each figure that this port will generate assuming we are under the kHierarchical convention. *Parameters* output_port: OutputPort The output port for which we are generating figure names. We can extract the system name, port name, and port size from this. port_component_name: str A "sub-component" of the port that we wish to give a unique name in the figures. *Returns* paths_out: List[Path] Each path in this list is a file path for an associated figure. """ # Setup plotting_options = self.plotting_options format = plotting_options.file_format # Compute the figure paths based on the arrangement of the plots # (i.e., if there is one plot per port, or one plot per dimension system = output_port.get_system() system_name = system.get_name() safe_system_name = self.compute_safe_system_name(system_name) port_name = output_port.get_name() log_sink_size = associated_log_sink.get_input_port().size() if plotting_options.plot_arrangement == PortFigureArrangement.OnePlotPerPort: if port_component_name is None: # If there is no sub-component name, then we just # create the file in the main plot directory return [ self.plot_dir / f"system_{safe_system_name}/port_{port_name}.{format}" ] else: # If there is a sub-component name, then we will # create a sub-directory for it return [ self.plot_dir / f"system_{safe_system_name}/port_{port_name}/{port_component_name}.{format}" ] elif plotting_options.plot_arrangement == PortFigureArrangement.OnePlotPerDim: if port_component_name is None: # If there is no sub-component name, then we just # create the files in the main plot directory return [ self.plot_dir / f"system_{safe_system_name}/port_{port_name}/dim_{self.name_of_data_at_index(ii, output_port, associated_log_sink, remove_spaces=True)}.{format}" for ii in range(log_sink_size) ] else: # If there is a sub-component name, then we will # create a sub-directory for it return [ self.plot_dir / f"system_{safe_system_name}/port_{port_name}/{port_component_name}/dim_{self.name_of_data_at_index(ii, output_port, associated_log_sink, remove_spaces=True)}.{format}" for ii in range(log_sink_size) ] else: raise NotImplementedError( f"Invalid plot arrangement for figure naming convention {plotting_options.figure_naming_convention}: {plotting_options.plot_arrangement}." )
[docs] def name_of_data_at_index( self, dim_index: int, target_port: OutputPort, associated_log_sink: VectorLogSink, remove_spaces: bool = False, ) -> str: """ *Description* Returns the name of the data which is in index dim_index of this vector-valued port. TODO(kwesi): Consider moving this to its own utility file, outside of the file manager. *Parameters* self : PortWatcherPlotter The PortWatcherPlotter object. dim_index : int The index of the data in the port. remove_spaces : bool Whether to remove spaces from the name. *Returns* name_of_data: str The name of the data at index dim_index. """ # Setup plotting_options = self.plotting_options n_dims_sink = associated_log_sink.get_input_port().size() system = target_port.get_system() # Input Processing if dim_index >= n_dims_sink: raise ValueError( f"dim_index ({dim_index}) is greater than the number of dimensions in the port ({n_dims_sink})." ) # Default name name = f"Dim #{dim_index}" # If we are using the OnePlotPerPort config, then use the default name if plotting_options.plot_arrangement == PortFigureArrangement.OnePlotPerPort: return name # Otherwise, try to get a better name # - For MultibodyPlants, we can get specific names for certain dimensions of the "state" output port if type(system) is MultibodyPlant: # The multi-body plant has names for specific ports if target_port.get_name() == "state": # We can get the names of the state state_names = system.GetStateNames() name = state_names[dim_index] # Filter our spaces, if requested if remove_spaces: name = name.replace(" ", "_") # Return name! return name
@property def plot_dir(self) -> Path: """ *Description* This function returns the directory where the plots will be saved. *Returns* plot_dir: Path The directory where the plots will be saved. """ return self.base_directory / "plots" @property def raw_data_dir(self) -> Path: """ *Description* This function returns the directory where the raw data will be saved. *Returns* raw_data_dir: Path The directory where the raw data will be saved. """ return self.base_directory / "raw_data"
[docs] def raw_data_file_path( self, system_name: str, port_name: str, port_component_name: str = None ) -> Path: """ *Description* This function returns the file name for saving raw data. *Parameters* port_component_name: str A "sub-component" of the port that we wish to give a unique name in the data. *Returns* raw_data_file_name: Path The file name for saving raw data. """ safe_system_name = self.compute_safe_system_name(system_name) if port_component_name is None: return self.raw_data_dir / f"system_{safe_system_name}_port_{port_name}_data.{self.raw_data_file_format}" else: return self.raw_data_dir / f"system_{safe_system_name}_port_{port_name}" / f"{port_component_name}.{self.raw_data_file_format}"
[docs] def time_data_file_path(self, system_name: str, port_name: str) -> Path: """ *Description* This function returns the file name for saving time data. *Returns* time_data_file_path: Path The file name for saving time data. """ safe_system_name = self.compute_safe_system_name(system_name) return self.raw_data_dir / f"system_{safe_system_name}_port_{port_name}_times.{self.raw_data_file_format}"