#!/usr/bin/env python
"""
Contains the ErrorModelDriver class definition
Please note that this module is private. The ErrorModelDriver class is
available in the main ``ananke`` namespace - use that instead.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Union, Set, List, Dict, Callable
from numpy.typing import ArrayLike, NDArray
from warnings import warn
from collections.abc import Iterable
import numpy as np
import pandas as pd
from . import utils
from ._default_error_model import *
from ._constants import *
if TYPE_CHECKING:
from .Ananke import Ananke
import Galaxia_ananke as Galaxia
__all__ = ['ErrorModelDriver']
# Public API of this module (see ``__all__`` above).
class ErrorModelDriver:
    """
    Proxy to the utilities for given error model driver parameters.

    The driver keeps a reference to its parent Ananke object, validates the
    configured error model at construction time and exposes, through the
    ``errors`` property, the catalogue columns holding the drawn errors.
    Attributes whose name starts with ``particle`` and that exist on the
    parent Ananke object are transparently forwarded to it.
    """
    # column name templates: '<property>_Sig' holds the standard deviation,
    # '<property>_Err' holds the drawn error added to the property
    _sigma_formatter = '{}_Sig'
    _sigma_template = _sigma_formatter.format
    _error_formatter = '{}_Err'
    _error_template = _error_formatter.format
    # extra column names appended to the catalogue keys of the dummy
    # dataframe used to probe the error model
    _extra_output_keys = ()

    def __init__(self, ananke: Ananke, **kwargs: Any) -> None:
        """
        Parameters
        ----------
        ananke : Ananke object
            The Ananke object that utilizes this ErrorModel object

        error_model : function [df --> dict(prop: coefficient)]
            Use to specify a model that returns error's standard deviations
            per property from characteristics of the mock star given in a
            dataframe format. The function must return the standard deviations
            per property in a dictionary format with keys corresponding to
            the property names returned by Galaxia (use property
            galaxia_catalogue_mag_and_astrometrics of the Ananke object).
            A plain dictionary, or an iterable mixing such functions and
            dictionaries, is also accepted: the resulting dictionaries are
            merged in order.
            By default, the class will query the chosen photometric systems
            to check if they have a default model to use. If it doesn't find
            one, it will emit a warning and fill the errors of that system
            with zero values.

        ignore : bool
            Flag exposed through the ``ignore`` property. Default to False.
        """
        self.__ananke: Ananke = ananke
        self.__parameters: Dict[str, Any] = kwargs
        self._test_error_model()

    def __getattr__(self, item):
        """
        Forward ``particle*`` attributes to the parent Ananke object.

        Only invoked once regular attribute lookup has failed. Any other
        name falls back to the regular lookup, which raises AttributeError.
        """
        # The cheap name test must come first: evaluating ``self.ananke`` for
        # an arbitrary missing name re-enters this method when the private
        # ananke slot is not set yet (bare ``__new__``, copy, unpickling) and
        # would recurse until RecursionError instead of AttributeError.
        if item.startswith('particle') and item in dir(self.ananke):
            return getattr(self.ananke, item)
        return self.__getattribute__(item)

    @property
    def ananke(self) -> Ananke:
        """Parent Ananke object given at construction."""
        return self.__ananke

    @property
    def galaxia_output(self) -> Galaxia.Output:
        """The ``_galaxia_output`` attribute of the parent Ananke object."""
        return self.ananke._galaxia_output

    @staticmethod
    def _expand_and_apply_error_model(df, error_model) -> Dict[str, ArrayLike]:
        """
        Evaluate every component of ``error_model`` and merge the results.

        Parameters
        ----------
        df : dataframe
            Passed as sole argument to each callable component.

        error_model : callable, mapping, or iterable of those
            Callables are called with ``df`` and must return an object with
            an ``items()`` method; non-callable components are used as is.

        Returns
        -------
        merged : dict
            Union of all components' items; on duplicate keys the latest
            component wins.
        """
        # function-scope import keeps this method self-contained
        from collections.abc import Mapping
        # A mapping is iterable too (over its keys), so it must be wrapped
        # explicitly, otherwise its keys would be mistaken for models.
        if isinstance(error_model, Mapping) or not isinstance(error_model, Iterable):
            error_model = [error_model]
        merged: Dict[str, ArrayLike] = {}
        for model in error_model:
            error_dict = model(df) if callable(model) else model
            merged.update(error_dict.items())  # TODO adapt to dataframe type of output?
        return merged

    def _make_dummy_df(self):
        """
        Build a one-row, all-NaN dataframe with the catalogue columns.

        Used to probe which keys the error model returns without needing
        an actual catalogue.
        """
        dummy_df = utils.RecordingDataFrame(
            [], columns=self.ananke.galaxia_catalogue_keys + self._extra_output_keys
            )  # TODO make use of dummy_df.record_of_all_used_keys
        dummy_df.loc[0] = np.nan
        return dummy_df

    def _test_error_model(self) -> None:
        """
        Run the error model on a dummy dataframe and check its returned keys
        with ``utils.compare_given_and_required`` against the parent's
        ``galaxia_catalogue_mag_and_astrometrics``.

        Raises
        ------
        KeyError
            Propagated unchanged when evaluating the error model raises it.
        """
        dummy_df = self._make_dummy_df()
        try:
            dummy_err = self._expand_and_apply_error_model(dummy_df, self.error_model)
        except KeyError:
            raise  # TODO make it more informative
        utils.compare_given_and_required(
            dummy_err.keys(), set(), self.ananke.galaxia_catalogue_mag_and_astrometrics,
            error_message="Given error model function returns wrong set of keys")

    @property
    def _sigma_keys(self) -> Set[str]:
        """Names of the ``*_Sig`` columns, one per catalogue magnitude name."""
        return set(map(self._sigma_template, self.ananke.galaxia_catalogue_mag_names))

    @property
    def _error_keys(self) -> Set[str]:
        """Names of the ``*_Err`` columns, one per catalogue magnitude name."""
        return set(map(self._error_template, self.ananke.galaxia_catalogue_mag_names))

    @classmethod
    def __pp_pipeline(cls, df: pd.DataFrame, error_keys: Set[str],
                      error_model: List[Union[Callable[[pd.DataFrame],
                                                       Dict[str, NDArray]],
                                              Dict[str, float]]]) -> None:
        """
        Post-processing pipeline adding the error columns to ``df`` in place.

        For every property returned by the error model, this writes its
        standard deviation column, draws a gaussian error scaled by that
        standard deviation, stores it, and adds it to the property column.

        Parameters
        ----------
        df : pandas.DataFrame
            Catalogue chunk, modified in place.

        error_keys : set of str
            Error column names expected in ``df``.

        error_model : callable, mapping, or iterable of those
            Passed to ``_expand_and_apply_error_model``.
        """
        if not error_keys.difference(df.columns):
            # every expected error column already exists: nothing to do
            # (presumably so errors are not drawn and added a second time)
            return
        for prop_name, error in cls._expand_and_apply_error_model(df, error_model).items():
            # pre-generate the keys to use for the standard error and its
            # actual gaussian drawn error of property prop_name
            prop_sig_name = cls._sigma_template(prop_name)
            prop_err_name = cls._error_template(prop_name)
            # column of the standard error values for property prop_name
            df[prop_sig_name] = error
            # column of the actual gaussian drawn error values
            df[prop_err_name] = error*np.random.randn(df.shape[0])
            # add the drawn error value to the existing quantity
            df[prop_name] += df[prop_err_name]

    @property
    def errors(self):  # TODO figure out output typing
        """
        Apply the error pipeline to the Galaxia output and return the
        selection of its ``*_Err`` columns.

        The pipeline is submitted through the output's
        ``apply_post_process_pipeline_and_flush`` method, after which the
        output's ``_pp_convert_icrs_to_galactic`` method is called.
        """
        galaxia_output = self.galaxia_output
        error_keys = self._error_keys
        galaxia_output.apply_post_process_pipeline_and_flush(
            self.__pp_pipeline, error_keys, self.error_model,
            flush_with_columns=tuple(self._error_prop_names))
        galaxia_output._pp_convert_icrs_to_galactic()
        return galaxia_output[list(error_keys)]

    @property
    def _error_prop_names(self) -> Set[str]:
        """Set of property names the error model returns on a dummy dataframe."""
        return set(self._expand_and_apply_error_model(self._make_dummy_df(), self.error_model).keys())

    @property
    def parameters(self) -> Dict[str, Any]:
        """Keyword arguments given at construction."""
        return self.__parameters

    @property
    def ignore(self) -> bool:
        """Value of the ``ignore`` parameter, False when not given."""
        return self.parameters.get('ignore', False)

    @property
    def error_model(self) -> List[Union[Callable[[pd.DataFrame], Dict[str, NDArray]], Dict[str, float]]]:  # TODO design
        """
        Error model in use: the ``error_model`` parameter when given,
        otherwise one model per photometric system of the parent Ananke
        object (its ``default_error_model`` attribute, or a zero-error
        fallback that warns when that attribute is missing).
        """
        try:
            # an explicit model short-circuits the construction of the
            # per-photosystem defaults, which is only needed as a fallback
            return self.parameters['error_model']
        except KeyError:
            pass
        return [
            getattr(psys, 'default_error_model',
                    self.__missing_default_error_model_for_photosystem(psys))
            for psys in self.ananke.galaxia_photosystems
            ]

    @staticmethod
    def __missing_default_error_model_for_photosystem(photosystem) -> Callable[[pd.DataFrame], Dict[str, NDArray]]:
        """
        Build the fallback error model for a photometric system lacking a
        ``default_error_model``: when called it warns and returns zero
        errors for each of the system's ``to_export_keys``.
        """
        def _return_zero_error_and_warn(df):
            warn(f"Method default_error_model isn't defined for photometric system {photosystem.key}", UserWarning, stacklevel=2)
            return {mag: np.zeros(df.shape[0]) for mag in photosystem.to_export_keys}
        return _return_zero_error_and_warn

    @staticmethod
    def __missing_default_error_model_for_isochrone(photosystem):
        """
        Deprecated alias of ``__missing_default_error_model_for_photosystem``;
        emits a DeprecationWarning and delegates to it.
        """
        warn('This static method will be deprecated, please use instead static method __missing_default_error_model_for_photosystem', DeprecationWarning, stacklevel=2)
        return ErrorModelDriver.__missing_default_error_model_for_photosystem(photosystem)
if __name__ == '__main__':
    # This module only provides definitions: running it as a script is not supported.
    raise NotImplementedError()