Source code for ananke.IntegratedLightDriver

#!/usr/bin/env python
#
# Author: Adrien CR Thob
# Copyright (C) 2022  Adrien CR Thob
#
# This file is part of the py-ananke project,
# <https://github.com/athob/py-ananke>, which is licensed
# under the GNU Affero General Public License v3.0 (AGPL-3.0).
# 
# The full copyright notice, including terms governing use, modification,
# and redistribution, is contained in the files LICENSE and COPYRIGHT,
# which can be found at the root of the source code distribution tree:
# - LICENSE <https://github.com/athob/py-ananke/blob/main/LICENSE>
# - COPYRIGHT <https://github.com/athob/py-ananke/blob/main/COPYRIGHT>
#
"""
Contains the IntegratedLightDriver class definition

Please note that this module is private. The IntegratedLightDriver class is
available in the main ``ananke`` namespace - use that instead.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Tuple, List, Dict, Callable
from numpy.typing import ArrayLike, NDArray
from functools import cached_property
import numpy as np
import pandas as pd
from scipy import integrate, interpolate
from astropy import constants, units, table

import galaxia_ananke as Galaxia
from galaxia_ananke.Output import vaex
from galaxia_ananke.photometry.Formatting import Formatting as Galaxia_photo_Formatting
from galaxia_ananke._templates import FTTAGS

from ._constants import *
from .utils import classproperty


if TYPE_CHECKING:
    from .Ananke import Ananke
    from galaxia_ananke.photometry.PhotoSystem import PhotoSystem

__all__ = ['IntegratedLightDriver']


def make_truncated_kroupa_imf(m_min: float = 0.01, m_max: float = 120) -> Tuple[Callable[[ArrayLike], NDArray], float, float, float]:
    """
    Build a normalized, truncated three-segment broken power-law IMF.

    The un-normalized distribution follows ``m**-0.3`` on
    ``[m_min, 0.08)``, ``0.08*m**-1.3`` on ``[0.08, 0.5)`` and
    ``0.04*m**-2.3`` on ``[0.5, m_max]`` (the prefactors make the segments
    continuous at the break points), and is zero outside ``[m_min, m_max]``.

    Parameters
    ----------
    m_min : float
        Lower truncation mass. Must satisfy ``0 <= m_min < 0.08``.
    m_max : float
        Upper truncation mass. Must be above 0.5.

    Returns
    -------
    imf : callable
        Function of mass returning the IMF normalized so that its integral
        over ``[m_min, m_max]`` is 1. A single-element input yields a
        scalar, anything else yields an array of the same shape.
    average_mass : float
        Mean mass of the normalized IMF over ``[m_min, m_max]``.
    m_min : float
        The lower truncation mass, as given.
    m_max : float
        The upper truncation mass, as given.

    Raises
    ------
    ValueError
        If the mass range does not bracket both break points, or if
        ``m_min`` is negative.
    """
    # Check given mass range
    if m_min >= 0.08 or m_max <= 0.5:
        raise ValueError("Mass range should have a lower bound below 0.08, and a upper bound above 0.5")
    if m_min < 0:
        # a negative base raised to a fractional power would silently turn
        # the normalization into a complex number
        raise ValueError("Mass range lower bound should not be negative")
    # Compute normalization constant (analytic integral of each segment)
    normalization_constant = (
        10/7*(0.08**0.7-m_min**0.7)
        + 8/30*(0.08**-0.3-0.5**-0.3)
        + 4/130*(0.5**-1.3-m_max**-1.3)
    )
    # Compute average mass over mass range (analytic integral of m*imf(m))
    average_mass = (
        10/17*(0.08**1.7-m_min**1.7)
        + 8/70*(0.5**0.7-0.08**0.7)
        + 4/30*(0.5**-0.3-m_max**-0.3)
    ) / normalization_constant

    # Vectorized Kroupa IMF
    def _kroupa_piecewise_imf(m):
        # Force a floating-point array: with integer input, zeros_like would
        # create an integer result array and truncate every value to 0
        m = np.atleast_1d(np.asarray(m, dtype=float))
        # Zero-initialised, so anything outside [m_min, m_max] (including
        # infinities and NaN, which match no mask below) stays at 0
        result = np.zeros_like(m)
        # Define piecewise conditions
        lower_mask = (m >= m_min) & (m < 0.08)
        result[lower_mask] = m[lower_mask] ** -0.3
        middle_mask = (m >= 0.08) & (m < 0.5)
        result[middle_mask] = 0.08*m[middle_mask] ** -1.3
        upper_mask = (m >= 0.5) & (m <= m_max)
        result[upper_mask] = 0.04*m[upper_mask] ** -2.3
        # Return scalar for single-element input; empty input stays empty
        return result[0] if result.size == 1 else result

    # Define normalized IMF
    def imf(m: ArrayLike) -> NDArray:
        return _kroupa_piecewise_imf(m) / normalization_constant

    return imf, average_mass, m_min, m_max
class IntegratedLightDriver:
    """
    Per-particle integrated-light computations for an Ananke object.

    For every particle, the nearest isochrone track of each photometric
    system is integrated over the default IMF to obtain integrated fluxes,
    and a residual IMF (default IMF minus an estimate of the distribution
    of stars present in the Galaxia output) is integrated to obtain the
    residual fluxes.

    All results are exposed as cached properties, computed on first access.
    """
    # Column keys shared with the Galaxia photometry formatting
    _mini = Galaxia_photo_Formatting._mini
    _mact = Galaxia_photo_Formatting._mact
    _lum = Galaxia_photo_Formatting._lum
    _teff = Galaxia_photo_Formatting._teff
    _grav = Galaxia_photo_Formatting._grav
    _label = Galaxia_photo_Formatting._label
    _ms_labels = [0,1]
    # (imf, average mass, lower mass bound, upper mass bound) of the default IMF
    _default_imf_utils: Tuple[Callable[[ArrayLike], NDArray], float, float, float] = make_truncated_kroupa_imf()
    # _hydrogen_burning_limit = table.QTable({  # https://ui.adsabs.harvard.edu/abs/2023A%26A...671A.119C
    #     Galaxia_photo_Formatting._mini: [(0.075*units.Msun).to(Galaxia_photo_Formatting._unit_mass)],
    #     Galaxia_photo_Formatting._lum: [(-4.19*units.LogUnit(units.Lsun)).to(Galaxia_photo_Formatting._unit_lum)],
    #     Galaxia_photo_Formatting._teff: [(1800*units.Kelvin).to(Galaxia_photo_Formatting._unit_temp)],
    #     Galaxia_photo_Formatting._radi: [(0.083*units.Rsun).to(Galaxia_photo_Formatting._unit_radi)]
    # })

    def __init__(self, ananke: Ananke, **kwargs: Any) -> None:
        """
        Parameters
        ----------
        ananke : Ananke object
            The Ananke object that utilizes this IntegratedLightDriver
            object

        **kwargs
            Additional parameters
        """
        self.__ananke: Ananke = ananke
        self.__parameters: Dict[str, Any] = kwargs

    @staticmethod
    def _zero_imf(m: ArrayLike) -> NDArray:
        """Distribution function that is zero everywhere.

        Declared as a static method so that ``self._zero_imf`` is a plain
        one-argument callable; a bare function stored on the class would be
        bound to the instance on attribute access and fail when called
        with a single mass argument.
        """
        return 0*m

    @classproperty
    def _default_imf(cls):
        return cls._default_imf_utils[0]

    @classproperty
    def _default_imf_average(cls):
        return cls._default_imf_utils[1]

    @classproperty
    def _default_imf_m_min(cls):
        return cls._default_imf_utils[2]

    @classproperty
    def _default_imf_m_max(cls):
        return cls._default_imf_utils[3]

    @property
    def ananke(self) -> Ananke:
        return self.__ananke

    @property
    def galaxia_output(self) -> Galaxia.Output:
        return self.ananke._galaxia_output

    @property
    def parameters(self) -> Dict[str, Any]:
        return self.__parameters

    @cached_property
    def particle_mass_factors(self) -> NDArray:
        """Particle masses divided by the average mass of the default IMF."""
        return self.ananke.particle_masses / self._default_imf_average  # TODO remember to replace with particles initial mass

    @cached_property
    def particle_total_star_numbers(self) -> NDArray:
        """Mass factors scaled by the survey ``fsample``, rounded to integers."""
        return np.round(self.galaxia_output.survey.fsample*self.particle_mass_factors).astype(int)

    @cached_property
    def particle_nearest_isochrone_tracks(self) -> List[List[table.QTable]]:
        """Nearest isochrone track of every particle, one list per photometric system."""
        return [
            photo_sys.get_nearest_isochrone_track_for_age_and_metallicity(self.ananke.particle_ages, self.ananke.particle_metallicities)
            for photo_sys in self.ananke.galaxia_photosystems
        ]

    @cached_property
    def particle_nearest_isochrone_mass_ranges(self) -> List[List[Tuple[float,float]]]:
        """First and last initial-mass values of every nearest isochrone track."""
        isochrone_tracks: List[List[table.QTable]] = self.particle_nearest_isochrone_tracks
        return [[tuple(qtable[self._mini][[0,-1]].value)
                 for qtable in qtables ]
                for qtables in isochrone_tracks]

    @cached_property
    def particle_output_mass_ranges(self) -> List[Tuple[float,float]]:
        """
        Initial-mass range, per particle, over which the nearest isochrone
        track of the first photometric system is brighter than the
        particle's magnitude limit.
        """
        magnitude_name = self.galaxia_output.parameter_magnitude_name
        # NOTE(review): the 0.6 mag margin is not explained in this file - confirm its origin
        magnitude_highs = 0.6+np.minimum(
            self.galaxia_output.parameter_abs_mag_hi,
            self.galaxia_output.parameter_app_mag_hi-self.ananke.particle_nearest_observed_distmod
        )
        return [
            self._mass_range_brighter_than(track, magnitude_name, mag_hi)
            for track, mag_hi in zip(self.particle_nearest_isochrone_tracks[0], magnitude_highs)
        ]

    def _mass_range_brighter_than(self, track: table.QTable, magnitude_name: str, mag_hi: float) -> Tuple[float,float]:
        """
        Return the (lower, upper) initial masses bracketing the rows of
        ``track`` whose ``magnitude_name`` value is ``<= mag_hi``.

        Each bound is interpolated linearly in physical flux between the
        outermost qualifying row and its neighbour on the other side of
        ``mag_hi``. When the track already starts (resp. still ends) on the
        bright side of the limit there is no such neighbour, and the first
        (resp. last) initial mass of the track is used instead.

        Raises
        ------
        IndexError
            If no row of the track satisfies ``magnitude <= mag_hi``.
        """
        magnitudes = track[magnitude_name].value
        brighter = np.where(magnitudes <= mag_hi)[0]
        flux_limit = units.Magnitude(mag_hi).physical

        def mass_at_limit(first_row):
            # interpolate mass against flux across rows first_row, first_row+1
            rows = track[first_row+np.arange(2)]
            return interpolate.interp1d(rows[magnitude_name].physical, rows[self._mini])(flux_limit)

        # lower bound: guard against row index -1 wrapping around to the end of the track
        lower = mass_at_limit(brighter[0]-1) if brighter[0] > 0 else track[self._mini][0].value
        # upper bound: fall back on the last *mass* of the track (not its magnitude)
        upper = mass_at_limit(brighter[-1]) if magnitudes[-1] > mag_hi else track[self._mini][-1].value
        return (lower, upper)

    @staticmethod
    def _unique_tracks_index_and_inverse(isochrone_tracks: List[List[table.QTable]]) -> List[Tuple[NDArray,NDArray]]:
        """
        For each photometric system, identify the distinct track objects (by
        ``id``) and return the ``(index, inverse)`` arrays of ``np.unique``:
        ``index`` points at one occurrence of each distinct track, and
        ``inverse`` maps every particle back to its distinct track.
        """
        return [
            np.unique(list(map(id, qtables)), return_inverse=True, return_index=True)[1:]
            for qtables in isochrone_tracks
        ]

    @staticmethod
    def _any_nan_expression(column_names: List[str]) -> str:
        """Build the expression string that is true where any given column is NaN."""
        return '|'.join(map(lambda c: f"({c}!={c})", column_names))

    @staticmethod
    def _make_residual_imf(imf: Callable[[ArrayLike], NDArray], sample_number_imf: Callable[[ArrayLike], NDArray], total_star_number: int) -> Callable[[ArrayLike], NDArray]:
        """
        Return ``m -> clip(imf(m) - sample_number_imf(m)/total_star_number, 0, inf)``.

        Built through a factory so that each returned function keeps its own
        ``sample_number_imf`` and ``total_star_number``; a lambda created
        directly in a comprehension would look both up at call time and
        only ever see the values of the last iteration.
        """
        def residual_imf(m: ArrayLike) -> NDArray:
            return np.clip(imf(m) - sample_number_imf(m)/total_star_number, 0, np.inf)
        return residual_imf

    @cached_property
    def particle_dimensionless_flux_interpolators_of_nearest_isochrone_tracks(self) -> pd.DataFrame:  # TODO could typing hint to DataFrame[interpolate.interp1d] ? pandera?
        """
        DataFrame with one row per particle and one column per exported
        magnitude key, holding the flux interpolator of the particle's
        nearest isochrone track. Interpolators are built once per distinct
        track and shared between the particles that use it.
        """
        isochrone_tracks: List[List[table.QTable]] = self.particle_nearest_isochrone_tracks
        unique_tracks_index_and_inverse: List[Tuple[NDArray,NDArray]] = self._unique_tracks_index_and_inverse(isochrone_tracks)
        unique_dimensionless_flux_interpolators: List[pd.DataFrame] = [  # TODO could typing hint to DataFrame[interpolate.interp1d] ? pandera?
            pd.DataFrame([
                self.make_dimensionless_flux_interpolators_from_qtable(qtables[i], photo_sys)
                for i in index
            ])
            for (index,_), qtables, photo_sys in zip(unique_tracks_index_and_inverse, isochrone_tracks, self.ananke.galaxia_photosystems)
        ]
        return pd.concat([
            flux_interpolators.loc[inverse].reset_index(drop=True)
            for (_,inverse),flux_interpolators in zip(unique_tracks_index_and_inverse, unique_dimensionless_flux_interpolators)
        ], axis=1)

    @cached_property
    def __particle_integrated_photometry(self) -> table.QTable:
        """
        Alternative computation of the integrated photometry that goes
        through the per-particle interpolator DataFrame (one integration
        per particle and band).
        """
        dimensionless_flux_interpolators: pd.DataFrame = self.particle_dimensionless_flux_interpolators_of_nearest_isochrone_tracks
        imf: Callable[[ArrayLike], NDArray] = self._default_imf
        particle_mass_factor: NDArray = self.particle_mass_factors
        dimensionless_integrated_fluxes: pd.DataFrame = dimensionless_flux_interpolators.apply(
            lambda series: particle_mass_factor*series.map(
                lambda interp1d: self.integrate_flux_interpolator(interp1d, imf)
            )
        )
        return table.QTable({key: series.to_numpy()*zeropoint
                             for (key,series),zeropoint in zip(dimensionless_integrated_fluxes.items(),self.ananke.photosystems_zeropoints)})

    @cached_property
    def particle_residual_distribution_functions(self) -> List[Callable[[ArrayLike], NDArray]]:
        """
        One residual IMF per particle: the default IMF minus the estimated
        number distribution of the particle's stars found in the Galaxia
        output (rows with a NaN magnitude excluded), divided by the
        particle's total star number, and clipped to be non-negative.
        Particles without any such star use a zero sampled distribution.
        """
        actual_magnames: List[str] = list(self.ananke.galaxia_catalogue_mag_names)
        sub_vaex: vaex.DataFrame = self.galaxia_output[~self.galaxia_output[self._any_nan_expression(actual_magnames)]][[self.galaxia_output._parentid, self.galaxia_output._mini]]
        sub_vaex_groupby_parentid: vaex.GroupBy = sub_vaex.groupby(self.galaxia_output._parentid)
        imf: Callable[[ArrayLike], NDArray] = self._default_imf
        sample_number_imfs: List[Callable[[ArrayLike], NDArray]] = [  # TODO this step may be highly inefficient, and vaex may have better tools
            self.estimate_number_distribution_function(
                sub_vaex_groupby_parentid.get_group(i)[self.galaxia_output._mini].to_numpy(),
                mass_range,
                # bind the current star number explicitly rather than by closure
                lambda m, n=total_star_number: n*imf(m),
                (self._default_imf_m_min, self._default_imf_m_max)
            ) if (i,) in sub_vaex_groupby_parentid.groups else self._zero_imf
            for i, mass_range, total_star_number in zip(self.ananke.particle_parentids, self.particle_output_mass_ranges, self.particle_total_star_numbers)
        ]
        return [self._make_residual_imf(imf, sample_number_imf, total_star_number)
                for sample_number_imf, total_star_number in zip(sample_number_imfs, self.particle_total_star_numbers)]

    @cached_property
    def particle_residual_photometry(self) -> table.QTable:
        """
        Per-particle fluxes obtained by integrating each particle's flux
        interpolators over its residual IMF, scaled by the particle mass
        factor and the photometric zeropoints.
        """
        dimensionless_flux_interpolators: pd.DataFrame = self.particle_dimensionless_flux_interpolators_of_nearest_isochrone_tracks
        residual_imfs: List[Callable[[ArrayLike], NDArray]] = self.particle_residual_distribution_functions
        particle_mass_factor: NDArray = self.particle_mass_factors
        dimensionless_integrated_fluxes: pd.DataFrame = dimensionless_flux_interpolators.transform(
            # pair every interpolator with the residual IMF of its particle
            lambda series: list(zip(series, residual_imfs))
        ).apply(
            lambda series: particle_mass_factor*series.map(
                lambda interp1d_imf: self.integrate_flux_interpolator(*interp1d_imf)
            )
        )
        return table.QTable({key: series.to_numpy()*zeropoint
                             for (key,series),zeropoint in zip(dimensionless_integrated_fluxes.items(),self.ananke.photosystems_zeropoints)})

    @cached_property
    def particle_integrated_photometry(self) -> table.QTable:
        """
        Per-particle fluxes obtained by integrating the nearest isochrone
        tracks over the default IMF, scaled by the particle mass factor.
        Each distinct track is integrated only once and the result is
        broadcast back to the particles that share it.
        """
        isochrone_tracks: List[List[table.QTable]] = self.particle_nearest_isochrone_tracks
        unique_tracks_index_and_inverse: List[Tuple[NDArray,NDArray]] = self._unique_tracks_index_and_inverse(isochrone_tracks)
        unique_tracks: List[List[table.QTable]] = [
            [qtables[i] for i in index]
            for (index,_),qtables in zip(unique_tracks_index_and_inverse, isochrone_tracks)
        ]
        imf: Callable[[ArrayLike], NDArray] = self._default_imf
        unique_integrated_photometry: List[table.QTable] = [
            table.vstack([self.integrate_qtable(qtable, photo_sys, imf) for qtable in qtables])
            for qtables, photo_sys in zip(unique_tracks, self.ananke.galaxia_photosystems)
        ]
        particle_mass_factor = self.particle_mass_factors
        return table.hstack([
            table.QTable({key: particle_mass_factor * column for key,column in qtable[inverse].items()})
            for (_,inverse),qtable in zip(unique_tracks_index_and_inverse, unique_integrated_photometry)
        ])

    @cached_property
    def __particle_residual_photometry(self) -> table.QTable:
        """
        Alternative residual photometry: ``fsample`` times the integrated
        photometry minus the summed fluxes of the stars in the Galaxia
        output (rows with a NaN magnitude excluded), per particle.
        """
        intrinsic_magnames: List[str] = list(self.ananke.intrinsic_catalogue_mag_names)
        actual_magnames: List[str] = list(self.ananke.galaxia_catalogue_mag_names)
        sub_vaex = self.galaxia_output[~self.galaxia_output[self._any_nan_expression(actual_magnames)]][[self.galaxia_output._parentid]+intrinsic_magnames]
        lambda_function = lambda m: units.Magnitude(m).physical
        # store the flux of each intrinsic magnitude under the matching catalogue magnitude name
        for i_mag, a_mag in zip(intrinsic_magnames, actual_magnames):
            sub_vaex[a_mag] = sub_vaex[i_mag].apply(lambda_function, vectorize=True)
        summed_df_per_parentid = sub_vaex.groupby(self.galaxia_output._parentid).agg(
            dict(zip(actual_magnames, len(actual_magnames)*['sum']))
        ).to_pandas_df().set_index(self.galaxia_output._parentid)
        # add zero-filled rows for particles without any star, then order as the particles
        summed_df_per_parentid = pd.concat([
            summed_df_per_parentid,
            pd.DataFrame({},index=list(set(self.ananke.particle_parentids)-set(summed_df_per_parentid.index)))
        ]).fillna(0).loc[self.ananke.particle_parentids]
        particle_catalogue_photometry = table.QTable({
            magname: summed.to_numpy()*zeropoint
            for (magname,summed),zeropoint in zip(summed_df_per_parentid.items(), self.ananke.photosystems_zeropoints)
        })
        fsample = self.galaxia_output.survey.fsample
        return table.QTable({
            magname: fsample*integrated-catalogue
            for (magname,integrated),(_,catalogue) in zip(self.particle_integrated_photometry.items(),particle_catalogue_photometry.items())
        })

    @classmethod
    def integrate_qtable(cls, qtable: table.QTable, photo_sys: PhotoSystem, imf: Callable[[ArrayLike], NDArray]) -> table.QTable:
        """
        Integrate every magnitude of an isochrone track over ``imf``.

        Parameters
        ----------
        qtable : QTable
            Isochrone track.
        photo_sys : PhotoSystem
            Photometric system the track belongs to; provides the magnitude
            names and the zeropoints.
        imf : callable
            Distribution function of mass used as integration weight.

        Returns
        -------
        QTable
            Single-row table with one column per exported magnitude key,
            holding the integrated flux multiplied by the zeropoint.
        """
        flux_interpolators = cls.make_dimensionless_flux_interpolators_from_qtable(qtable, photo_sys)
        return table.QTable({mag_name: zeropoint*[cls.integrate_flux_interpolator(flux_interpolator, imf)]
                             for (mag_name,flux_interpolator),zeropoint in zip(flux_interpolators.items(), photo_sys.zeropoints)})

    @classmethod
    def make_dimensionless_flux_interpolators_from_qtable(
            cls, qtable: table.QTable, photo_sys: PhotoSystem
        ) -> Dict[str,interpolate.interp1d]:
        """
        Build, for every magnitude column of ``photo_sys`` in ``qtable``, a
        linear interpolator of the physical flux as a function of initial
        mass (converted to solar masses), keyed by exported magnitude key.
        """
        return {photo_sys.mag_names_to_export_keys_mapping[mag_name]: interpolate.interp1d(qtable[cls._mini].to(units.Msun).value, magnitude.physical[np.newaxis])
                for mag_name, magnitude in qtable[photo_sys.mag_names].items()}

    @staticmethod
    def integrate_flux_interpolator(flux_interpolator: interpolate.interp1d, imf: Callable[[ArrayLike], NDArray]) -> float:
        """
        Integrate ``imf(m)*flux_interpolator(m)`` between the first and last
        nodes of the interpolator and return the value of the integral.
        """
        return integrate.quad(lambda m: imf(m)*flux_interpolator(m),
                              flux_interpolator.x[0], flux_interpolator.x[-1])[0]

    @staticmethod
    def estimate_number_distribution_function(masses: NDArray, mass_range: Tuple[float, float], number_imf: Callable[[ArrayLike], NDArray], imf_range: Tuple[float, float], n_windows: int = 20) -> interpolate.interp1d:
        """
        Estimate the number distribution (stars per unit mass) of a sample
        of masses as a piecewise-linear function capped by a reference
        number IMF.

        Parameters
        ----------
        masses : array
            Masses of the sampled stars. Must not be empty.
        mass_range : tuple of 2 floats
            Mass range covered by the sample; widened if needed so that it
            strictly contains every mass of the sample.
        number_imf : callable
            Reference number distribution as a function of mass.
        imf_range : tuple of 2 floats
            Mass range of the reference distribution; the estimate is set
            to zero at both of its ends.
        n_windows : int
            Number of windows used to bin the sorted masses.

        Returns
        -------
        interp1d
            Linear interpolator of the estimated number distribution.
        """
        # clean up mass range (float copy, so the bounds can be adjusted in place)
        mass_range = np.sort(np.asarray(mass_range, dtype=float))
        masses = np.sort(masses)
        mass_range[0] = min(mass_range[0], np.nextafter(masses[0],0))
        mass_range[1] = max(mass_range[1], np.nextafter(masses[-1],np.inf))
        # make initial stepwise distribution function estimate: bin edges are
        # sampled masses taken every `window` ranks around the median rank,
        # and each bin gets the number of ranks it spans (0.5 for each of
        # the two outer bins) divided by its width
        window = int(np.ceil(masses.shape[0]/n_windows))
        indices = np.unique(np.clip(masses.shape[0]//2 + window*(np.arange(n_windows+1) - n_windows//2), 0, masses.shape[0]-1))
        distribution_masses = np.hstack([mass_range[0], masses[indices], mass_range[1]])
        mass_diff = np.diff(distribution_masses)
        distribution = np.hstack([0.5,np.diff(indices),0.5])/mass_diff
        # get steps centroids and find number IMF values
        distribution_mass_centroids = (distribution_masses[:-1]+distribution_masses[1:])/2
        number_imf_at_centroids = number_imf(distribution_mass_centroids)
        # from peak to trough, assure estimate at centroids is under number IMF:
        # any excess over the IMF is removed from the current bin and its
        # star count is handed to the next bin in the ordering
        peak_to_trough = np.argsort((np.log10(number_imf_at_centroids) - np.log10(distribution))*mass_diff)
        for i in range(peak_to_trough.shape[0]-1):
            current = peak_to_trough[i]
            nextone = peak_to_trough[i+1]
            residual = number_imf_at_centroids[current]-distribution[current]
            if residual < 0:
                # TODO redistribution could take into account proximity of remaining mass bins
                distribution[current] -= np.abs(residual)
                distribution[nextone] += np.abs(residual)*mass_diff[current]/mass_diff[nextone]
        # refine estimate by shifting from stepwise to linear, conserving integral:
        # nodes are the two outer edges (outer step values), the centroids
        # (step values) and the inner edges (weighted mean of the two
        # adjacent step values)
        masses = distribution_masses[1:-1]
        distribution_masses = np.hstack([distribution_masses[[0,-1]], distribution_mass_centroids, masses])
        temp = np.argsort(distribution_masses)
        # the outer edges are repeated with a zero value to close the
        # distribution at both ends of the mass range
        distribution_masses = np.hstack([mass_range[0],distribution_masses[temp],mass_range[1]])
        distribution = np.hstack([0,np.hstack([distribution[[0,-1]],distribution,((masses-distribution_mass_centroids[:-1])*distribution[:-1] + (distribution_mass_centroids[1:]-masses)*distribution[1:])/np.diff(distribution_mass_centroids)])[temp],0])
        # filling zeros and applying interpolation
        distribution_masses = np.hstack([imf_range[0], distribution_masses, imf_range[1]])
        distribution = np.hstack([0, distribution, 0])
        return interpolate.interp1d(distribution_masses, distribution)
if __name__ == '__main__':
    # This module only provides definitions: running it as a script is not supported
    raise NotImplementedError()