from abc import ABC, abstractmethod
import os
import h5py
import logging
import astropy.units as u
from typing import List, Union, Optional
from rubix import config
from rubix.logger import get_logger
[docs]
def create_rubix_galaxy(
file_path: str,
particle_data: dict,
galaxy_data: dict,
simulation_metadata: dict,
units: dict,
config: dict,
logger: logging.Logger,
):
"""
Create a Rubix file with the given data.
Args:
file_path (str): Path to save the Rubix file.
particle_data (dict): Dictionary containing the particle data.
galaxy_data (dict): Dictionary containing the galaxy data.
simulation_metadata (dict): Dictionary containing the simulation metadata.
units (dict): Dictionary containing the units.
config (dict): Dictionary containing the configuration.
logger (logging.Logger): Logger object to log messages.
Returns:
None
"""
logger.debug("Creating Rubix file at path: %s", file_path)
with h5py.File(file_path, "w") as f:
# Create groups
meta_group = f.create_group("meta")
galaxy_group = f.create_group("galaxy")
particle_group = f.create_group("particles")
# Save the simulation metadata
for key, value in simulation_metadata.items():
meta_group.create_dataset(key, data=value)
# Save the galaxy data: Create a dataset for each field and add the units as attributes
for key, value in galaxy_data.items():
logger.debug(
f"Converting {key} for galaxy data into {config['galaxy'][key]}"
)
value = u.Quantity(value, units["galaxy"][key]).to(config["galaxy"][key])
galaxy_group.create_dataset(key, data=value)
galaxy_group[key].attrs["unit"] = config["galaxy"][key]
# Save the particle data: Create a dataset for each field and add the units as attributes
for key in particle_data:
particle_group.create_group(key)
for field, value in particle_data[key].items():
logger.debug(
f"Converting {field} for particle type {key} into {config['particles'][key][field]}"
)
value = u.Quantity(value, units[key][field]).to(
config["particles"][key][field]
)
particle_group[key].create_dataset(field, data=value) # type: ignore
particle_group[key][field].attrs["unit"] = config["particles"][key][field] # type: ignore
logger.info(f"Rubix file saved at {file_path}")
[docs]
class BaseHandler(ABC):
"""
Base class for handling input data and converting it to Rubix format.
Args:
config (dict): Configuration for the BaseHandler.
_logger (logging.Logger): Logger object to log messages.
"""
def __init__(self, logger_config=None):
"""Initializes the BaseHandler class"""
self.config = config["BaseHandler"]
self._logger = get_logger(logger_config)
[docs]
@abstractmethod
def get_particle_data(self) -> dict:
"""Returns the particle data in the required format"""
[docs]
@abstractmethod
def get_galaxy_data(self) -> dict:
"""Returns the galaxy data in the required format"""
[docs]
@abstractmethod
def get_units(self) -> dict:
"""Returns the units in the required format"""
[docs]
def to_rubix(self, output_path: str):
"""
Converts the input data to Rubix format and saves it to the output path.
Args:
output_path (str): Path to save the Rubix file.
"""
self._logger.debug("Converting to Rubix format..")
os.makedirs(output_path, exist_ok=True)
# Get the data
particle_data = self.get_particle_data()
galaxy_data = self.get_galaxy_data()
simulation_metadata = self.get_simulation_metadata()
# Get the units
units = self.get_units()
# Check if the input data is valid and in the correct format
self._check_data(particle_data, galaxy_data, simulation_metadata, units)
# Create the Rubix h5 file
file_path = os.path.join(output_path, "rubix_galaxy.h5")
self._logger.info(f"Rubix file saved at {file_path}")
# Create the Rubix file
create_rubix_galaxy(
file_path,
particle_data,
galaxy_data,
simulation_metadata,
units,
self.config,
self._logger,
)
def _check_data(self, particle_data, galaxy_data, simulation_metadata, units):
# Check if all required fields are present
self._check_galaxy_data(galaxy_data, units)
self._check_particle_data(particle_data, units)
self._check_simulation_metadata(simulation_metadata)
def _check_simulation_metadata(self, simulation_metadata):
"""Check if all required fields are present in the simulation metadata
Currently we do not have any required fields to check here.
"""
def _check_galaxy_data(self, galaxy_data, units):
# Check if all required fields are present
for field in self.config["galaxy"]:
if field not in galaxy_data:
raise ValueError(f"Missing field {field} in galaxy data")
# Check if the units are correct
for field in galaxy_data:
if field not in units["galaxy"]:
raise ValueError(f"Units for {field} not found in units")
"""
def _check_particle_data(self, particle_data, units):
# Check if all required fields are present
for key in self.config["particles"]:
if key not in particle_data:
raise ValueError(f"Missing particle type {key} in particle data")
for field in self.config["particles"][key]:
if field not in particle_data[key]:
raise ValueError(
f"Missing field {field} in particle data for particle type {key}"
)
# Check if the units are correct
for key in particle_data:
for field in particle_data[key]:
if field not in units[key]:
raise ValueError(f"Units for {field} not found in units")
"""
def _check_particle_data(self, particle_data, units):
# Get the list of expected particle types from the configuration
expected_particle_types = list(self.config["particles"].keys())
# Find the particle types that are actually present in the particle data
present_particle_types = [
key for key in expected_particle_types if key in particle_data
]
# If none of the expected particle types are present, raise a ValueError
if not present_particle_types:
raise ValueError(
f"None of the expected particle types {expected_particle_types} are present in particle data"
)
# For each particle type that is present, check that all required fields are present
for particle_type in present_particle_types:
required_fields = self.config["particles"][particle_type]
for field in required_fields:
if field not in particle_data[particle_type]:
raise ValueError(
f"Missing field {field} in particle data for particle type {particle_type}"
)
# Now check if units are specified for all fields in particle_data[particle_type]
particle_units = units.get(particle_type, {})
for field in particle_data[particle_type]:
if field not in particle_units:
raise ValueError(
f"Units for {field} not found in units for particle type {particle_type}"
)