Source code for galight.fitting_specify

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 14 12:16:43 2020

@author: Xuheng Ding
"""

import numpy as np
import copy
import lenstronomy
from packaging import version
import warnings

class FittingSpecify(object):
    """
    A class to generate the materials for the 'FittingSequence', defined by 'lenstronomy'.

    Key materials include the following, which are prepared by 'prepare_fitting_seq()':
        - kwargs_data_joint: data materials
        - kwargs_model: a list of class
        - kwargs_constraints
        - kwargs_likelihood
        - kwargs_params
        - imageModel
    """

    def __init__(self, data_process_class, sersic_major_axis=True):
        """
        Copy the information needed for the fitting out of 'data_process_class'.

        Parameter
        --------
            data_process_class: object.
                Must provide the attributes read below: deltaPix, target_stamp, zp,
                apertures, mask_apertures, header, target_pos and segm_deblend.

            sersic_major_axis: bool or None.
                Stored as 'self.sersic_major_axis' and later forwarded to the light model.
                If None, the value is read from the lenstronomy convention configuration
                (only done for lenstronomy >= 1.9.0).
        """
        self.data_process_class = data_process_class
        self.deltaPix = data_process_class.deltaPix
        self.numPix = len(self.data_process_class.target_stamp)
        self.zp = data_process_class.zp
        # Deep copies, so that later edits of the fitting setup do not alter the data class.
        self.apertures = copy.deepcopy(data_process_class.apertures)
        self.mask_apertures = copy.deepcopy(data_process_class.mask_apertures)
        self.header = copy.deepcopy(data_process_class.header)
        self.target_pos = copy.deepcopy(data_process_class.target_pos)
        self.segm_deblend = np.array(data_process_class.segm_deblend)
        if sersic_major_axis is None:
            # If sersic_major_axis == None, the sersic_major_axis follows Lenstronomy.
            if version.parse(lenstronomy.__version__) >= version.parse("1.9.0"):
                from lenstronomy.Conf import config_loader
                convention_conf = config_loader.conventions_conf()
                self.sersic_major_axis = convention_conf['sersic_major_axis']
            else:
                # Always define the attribute: it is read unconditionally by
                # sepc_kwargs_model() and prepare_fitting_seq().
                # NOTE(review): older lenstronomy has no convention file to read, so the
                # value is left unspecified (None) -- confirm this is the wanted fallback.
                self.sersic_major_axis = None
        else:
            self.sersic_major_axis = sersic_major_axis

    def sepc_kwargs_data(self, supersampling_factor = 2, psf_data = None, psf_error_map = None):
        """
        Prepare 'kwargs_data', 'kwargs_psf', 'kwargs_numerics' and 'kwargs_data_joint'.

        Parameter
        --------
            supersampling_factor: int.
                Stored in kwargs_numerics['supersampling_factor'].

            psf_data: array or None.
                PSF kernel. If None, data_process_class.PSF_list[psf_id_for_fitting] is used.

            psf_error_map: array or None.
                If not None, stored in kwargs_psf['psf_error_map'].
        """
        import lenstronomy.Util.simulation_util as sim_util
        # inverse: if True, coordinate system is ra to the left, if False, to the right
        kwargs_data = sim_util.data_configure_simple(self.numPix, self.deltaPix, inverse=True)
        kwargs_data['image_data'] = self.data_process_class.target_stamp
        kwargs_data['noise_map'] = self.data_process_class.noise_map

        if psf_data is None:
            psf_data = self.data_process_class.PSF_list[self.data_process_class.psf_id_for_fitting]
        kwargs_psf = {'psf_type': 'PIXEL', 'kernel_point_source': psf_data, 'pixel_size': self.deltaPix}
        if psf_error_map is not None:
            kwargs_psf['psf_error_map'] = psf_error_map

        # Here we super-sample the resolution of some of the pixels where the surface
        # brightness profile has a high gradient: a central box of half-width numPix/10.
        center = int(self.numPix/2)
        half_box = int(self.numPix/10)
        supersampled_indexes = np.zeros((self.numPix, self.numPix), dtype=bool)
        supersampled_indexes[center-half_box:center+half_box,
                             center-half_box:center+half_box] = True
        kwargs_numerics = {'supersampling_factor': supersampling_factor,
                           'compute_mode': 'adaptive',
                           'supersampled_indexes': supersampled_indexes}

        image_band = [kwargs_data, kwargs_psf, kwargs_numerics]
        multi_band_list = [image_band]
        self.kwargs_data = kwargs_data
        self.kwargs_psf = kwargs_psf
        self.kwargs_numerics = kwargs_numerics
        # multi_band_type options: 'single-band', 'multi-linear', 'joint-linear'
        self.kwargs_data_joint = {'multi_band_list': multi_band_list, 'multi_band_type': 'multi-linear'}

    def sepc_kwargs_model(self, extend_source_model = ['SERSIC_ELLIPSE'] * 1, point_source_num = 1):
        """
        Prepare the 'kwargs_model' for the fitting.

        Parameter
        --------
            extend_source_model: list of str, or None.
                Light profile names of the extended sources. None or an empty list means
                no extended source. The list is copied, so neither the caller's list nor
                the default list is shared with the instance.

            point_source_num: int or None.
                Number of point sources; None is treated as 0.
        """
        if point_source_num is None:
            # prepare_fitting_seq() allows None to mean "no point source".
            point_source_num = 0
        point_source_list = ['UNLENSED'] * point_source_num
        kwargs_model = {'point_source_model_list': point_source_list}
        light_model_list = [] if extend_source_model is None else list(extend_source_model)
        if light_model_list:
            kwargs_model['lens_light_model_list'] = light_model_list
        self.point_source_list = point_source_list
        self.light_model_list = light_model_list
        kwargs_model['sersic_major_axis'] = self.sersic_major_axis
        self.kwargs_model = kwargs_model

    def sepc_kwargs_constraints(self, fix_center_list = None):
        """
        Prepare the 'kwargs_constraints' for the fitting.

        Parameter
        --------
            fix_center_list: list.
                -if not None, describe how to fix the center [[0,0]] for example.
                This list defines how to 'joint_lens_light_with_point_source' definied by lenstronomy:
                    [[i_point_source, k_lens_light], [...], ...], see
                    https://lenstronomy.readthedocs.io/en/latest/_modules/lenstronomy/Sampling/parameters.html?highlight=joint_lens_light_with_point_source#
                for example [[0, 1]], joint first (0) point source with the second extend source (1).
        """
        # One image per point source model.
        kwargs_constraints = {'num_point_source_list': [1] * len(self.point_source_list)}
        if fix_center_list is not None:
            kwargs_constraints['joint_lens_light_with_point_source'] = fix_center_list
        self.kwargs_constraints = kwargs_constraints

    def sepc_kwargs_likelihood(self, condition=None):
        """
        Prepare the 'kwargs_likelihood' for the fitting.

        Most default values will be assigned.

        Parameter
        --------
            condition: input as a defination.
                Set up extra prior. For example if one want the first component have lower
                Sersic index, it can be set by first define a condition:

                    def condition_def(kwargs_lens, kwargs_source, kwargs_lens_light, kwargs_ps, kwargs_special, kwargs_extinction):
                        logL = 0
                        cond_0 = (kwargs_source[0]['n_sersic'] > kwargs_source[1]['n_sersic'])
                        if cond_0:
                            logL -= 10**15
                        return logL

                Then assign to condition:
                    fit_sepc.prepare_fitting_seq(**, condition = condition_def)
        """
        kwargs_likelihood = {'check_bounds': True,  # Set the bounds; if exceeded, return "penalty"
                             'image_likelihood_mask_list': [self.data_process_class.target_mask],
                             'custom_logL_addition': condition
                             }
        # NOTE(review): both keys below are taken to belong to this branch -- confirm
        # against the upstream file, the collapsed source is ambiguous.
        if self.light_model_list != []:
            # In likelihood_module.LikelihoodModule -- whether to fully invert the
            # covariance matrix for marginalization
            kwargs_likelihood['source_marg'] = False
            # Penalty if any component's flux is 'negative'.
            kwargs_likelihood['check_positive_flux'] = True
        self.kwargs_likelihood = kwargs_likelihood

    def _guess_ps_pixel_centers(self, neighborhood_size, threshold):
        """
        Guess one centre per point source from the local maxima of the target stamp.

        The search is repeated with half the threshold if too few maxima are found. If
        there are still too few, a warning is issued and the found positions (or the
        frame centre, if nothing was found) are repeated for every point source.

        Return
        --------
            A list of [x, y] offsets from int(numPix/2), in pixels, with the sign of x
            flipped, ordered from the brightest maximum downwards.
        """
        from galight.tools.measure_tools import find_loc_max
        n_ps = len(self.point_source_list)
        stamp = self.data_process_class.target_stamp
        # Automatically find the local max as PS center.
        x, y = find_loc_max(stamp, neighborhood_size = neighborhood_size, threshold = threshold)
        if len(x) < n_ps:
            x, y = find_loc_max(stamp, neighborhood_size = neighborhood_size, threshold = threshold/2)
            if len(x) < n_ps:
                warnings.warn("\nWarning: could not find the enough number of local max to match the PS numbers. Thus, all the initial PS set the same initial parameters.")
                if len(x) == 0:
                    x, y = [self.numPix/2], [self.numPix/2]
                # Repeat the candidates so that every point source gets a starting
                # position (this also covers the frame-centre fallback above).
                x = list(x) * n_ps
                y = list(y) * n_ps
        # NOTE(review): the stamp is indexed as [x, y]; confirm this matches the axis
        # order returned by find_loc_max.
        flux_ = [stamp[int(x_i), int(y_i)] for x_i, y_i in zip(x, y)]
        brightest_first = np.flipud(np.argsort(flux_))
        half_frame = int(self.numPix/2)
        ps_x = -1 * (np.array(x) - half_frame)
        ps_y = np.array(y) - half_frame
        return [[ps_x[k], ps_y[k]] for k in brightest_first[:n_ps]]

    def sepc_kwargs_params(self, source_params = None, fix_n_list = None, fix_Re_list = None, ps_params = None, ps_pix_center_list= None,
                           neighborhood_size = 4, threshold = 5, apertures_center_focus = False):
        """
        Setting up the 'kwargs_params' (i.e., the parameters) for the fitting.

        If 'source_params' or 'ps_params' are given, rather then setting as None, then,
        the input settings will be used. Also sets 'self.center_pix_pos', the point source
        positions converted back to pixel coordinates.

        Parameter
        --------
            fix_n_list: list.
                Describe a prior if want to fix the Sersic index.
                e.g., fix_n_list= [[0,4], [1,1]], means the first (i.e., 0) fix n = 4; the second (i.e., 1) fix n = 1.

            fix_Re_list: list.
                Describe a prior if want to fix the Sersic effective radius.
                e.g., fix_Re_list= [[0,0.4], [1,1]], means the first (i.e., 0) fix Reff value as 0.4.

            ps_pix_center_list: list or None.
                Point source centres [[x, y], ...] in pixels; must contain one entry per
                point source. The list is not modified.

            neighborhood_size, threshold:
                Passed to 'find_loc_max' when the point source centres are searched
                automatically.

            apertures_center_focus: bool.
                If true, the default parameters will have strong prior so that the center of the fitted Sersic will
                be closer to the apertures.

        Raise
        --------
            ValueError: if the length of 'ps_pix_center_list' differs from the number of point sources.
        """
        kwargs_params = {}
        if self.light_model_list != []:
            if source_params is None:
                source_params = source_params_generator(frame_size = self.numPix,
                                                        apertures = self.apertures,
                                                        deltaPix = self.deltaPix,
                                                        fix_n_list = fix_n_list,
                                                        fix_Re_list = fix_Re_list,
                                                        apertures_center_focus = apertures_center_focus)
            kwargs_params['lens_light_model'] = source_params

        if ps_params is None and len(self.point_source_list) > 0:
            if ps_pix_center_list is None:
                center_list = self._guess_ps_pixel_centers(neighborhood_size, threshold)
            else:
                if len(ps_pix_center_list) != len(self.point_source_list):
                    raise ValueError("Point source number mismatch between ps_pix_center_list and point_source_num")
                # Flip x into a new list: the caller's list must stay untouched so that
                # calling this method twice gives the same result.
                center_list = [[-center[0], center[1]] for center in ps_pix_center_list]
            ps_params = ps_params_generator(centers = center_list,
                                            deltaPix = self.deltaPix)
        kwargs_params['point_source_model'] = ps_params

        center_pix_pos = []
        if len(self.point_source_list) > 0:
            # Undo the x flip and the pixel scaling to get back to pixel coordinates.
            for kwargs_ps in ps_params[0]:
                x = -1 * kwargs_ps['ra_image'][0]/self.deltaPix
                y = kwargs_ps['dec_image'][0]/self.deltaPix
                center_pix_pos.append([x, y])
            center_pix_pos = np.array(center_pix_pos)
            center_pix_pos = center_pix_pos + int(self.numPix/2)
        self.center_pix_pos = center_pix_pos
        self.kwargs_params = kwargs_params

    def sepc_imageModel(self, sersic_major_axis):
        """
        Build the lenstronomy classes (data, PSF, light model, point source, image model).

        Parameter
        --------
            sersic_major_axis:
                Forwarded to lenstronomy's 'LightModel'. If 'LightModel' rejects the
                keyword, the model is built without it.
        """
        from lenstronomy.ImSim.image_model import ImageModel
        from lenstronomy.Data.imaging_data import ImageData
        from lenstronomy.Data.psf import PSF
        from lenstronomy.PointSource.point_source import PointSource
        from lenstronomy.LightModel.light_model import LightModel
        data_class = ImageData(**self.kwargs_data)
        pointSource = PointSource(point_source_type_list=self.point_source_list)
        psf_class = PSF(**self.kwargs_psf)
        try:
            # By this setting: fit_sepc.lightModel.func_list[1]._sersic_major_axis
            lightModel = LightModel(light_model_list=self.light_model_list, sersic_major_axis=sersic_major_axis)
        except TypeError:
            # Only the "unexpected keyword argument" case is handled here; any other
            # failure is a real error and must not be hidden.
            lightModel = LightModel(light_model_list=self.light_model_list)
            if version.parse(lenstronomy.__version__) >= version.parse("1.9.0"):
                warnings.warn("\nWarning: The current Lenstronomy Version doesn't not allow for sersic_major_axis=True. Please update you Lenstrnomy version or change you Lenstronomy configure file.")

        if self.light_model_list is None:
            imageModel = ImageModel(data_class, psf_class,
                                    point_source_class=pointSource, kwargs_numerics=self.kwargs_numerics)
        else:
            imageModel = ImageModel(data_class, psf_class, lens_light_model_class=lightModel,
                                    point_source_class=pointSource, kwargs_numerics=self.kwargs_numerics)
        self.data_class = data_class
        self.psf_class = psf_class
        self.lightModel = lightModel
        self.imageModel = imageModel
        self.pointSource = pointSource

    def plot_fitting_sets(self, savename = None, show_plot=True):
        """
        To make a plot show how the data will be fitted.

        The extend source will be shown using aperture, point source will be show as point source.

        Parameter
        --------
            savename: None or string.
                -Defining the saving name.

            show_plot: bool.
                -Plot or not plot. Note that figure can be saved without shown.
        """
        from galight.tools.measure_tools import plot_data_apertures_point
        # The image is multiplied by the likelihood mask before plotting.
        masked_image = self.kwargs_data['image_data'] * self.kwargs_likelihood['image_likelihood_mask_list'][0]
        plot_data_apertures_point(masked_image,
                                  self.apertures, self.center_pix_pos, savename = savename, show_plot=show_plot)

    def prepare_fitting_seq(self, supersampling_factor = 2, psf_data = None,
                            extend_source_model = None,
                            point_source_num = 0, ps_pix_center_list = None,
                            fix_center_list = None, source_params = None,
                            fix_n_list = None, fix_Re_list = None, ps_params = None, condition = None,
                            neighborhood_size = 4, threshold = 5, apertures_center_focus = False,
                            psf_error_map = None, mpi = False):
        """
        Key function used to prepared for the fitting.

        Parameters will be passed to the corresponding functions. If 'extend_source_model'
        is None, one 'SERSIC_ELLIPSE' per aperture is used. If 'point_source_num' is 0 or
        None, the point source entries are removed from kwargs_params, kwargs_constraints
        and kwargs_model.
        """
        self.mpi = mpi
        if point_source_num is None:
            # Normalise early: the model setup multiplies a list by this number.
            point_source_num = 0
        if extend_source_model is None:
            extend_source_model = ['SERSIC_ELLIPSE'] * len(self.apertures)
        self.sepc_kwargs_data(supersampling_factor = supersampling_factor, psf_data = psf_data, psf_error_map = psf_error_map)
        self.sepc_kwargs_model(extend_source_model = extend_source_model, point_source_num = point_source_num)
        self.sepc_kwargs_constraints(fix_center_list = fix_center_list)
        self.sepc_kwargs_likelihood(condition)
        self.sepc_kwargs_params(source_params = source_params, fix_n_list = fix_n_list, fix_Re_list = fix_Re_list,
                                ps_params = ps_params, neighborhood_size = neighborhood_size, threshold = threshold,
                                apertures_center_focus = apertures_center_focus, ps_pix_center_list = ps_pix_center_list)
        if point_source_num == 0:
            del self.kwargs_params['point_source_model']
            del self.kwargs_constraints['num_point_source_list']
            del self.kwargs_model['point_source_model_list']
        self.sepc_imageModel(sersic_major_axis = self.sersic_major_axis)
        print("The settings for the fitting is done. Ready to pass to FittingProcess. \n  However, please make updates manullay if needed.")

    def build_fitting_seq(self):
        """
        Create the lenstronomy 'FittingSequence' from the prepared kwargs and store it
        as 'self.fitting_seq'. 'prepare_fitting_seq()' must be called first.
        """
        from lenstronomy.Workflow.fitting_sequence import FittingSequence
        self.fitting_seq = FittingSequence(self.kwargs_data_joint, self.kwargs_model,
                                           self.kwargs_constraints, self.kwargs_likelihood,
                                           self.kwargs_params, mpi=self.mpi)
def _as_fix_table(fix_list):
    """Convert a [[index, value], ...] list to a 2-column array once; None if nothing is fixed."""
    if fix_list is None:
        return None
    table = np.array(fix_list)
    if table.size == 0:
        return None
    return table


def _fixed_value_for(table, index, label):
    """Return the value fixed for component 'index' (None if not fixed); raise if assigned more than once."""
    if table is None:
        return None
    values = table[:, 1][table[:, 0] == index]
    if len(values) == 0:
        return None
    if len(values) != 1:
        raise ValueError("fix_{1} are not assigned correctly - {0} component have two assigned values.".format(index, label))
    return values[0]


def source_params_generator(frame_size, apertures = [], deltaPix = 1, fix_n_list = None, fix_Re_list = None,
                            apertures_center_focus = False):
    """
    Quickly generate a source parameters for the fitting.

    Parameter
    --------
        frame_size: int.
            The frame size, to define the center of the frame

        apertures:
            The apertures of the targets. Each one must provide 'a', 'b', 'theta' and
            'positions' (either a single (x, y) pair or a sequence whose first entry is
            the (x, y) pair). The default list is only read, never modified.

        deltaPix:
            The pixel size of the data

        fix_n_list:
            A list to define how to fix the sersic index, default = None
            -for example: fix_n_list = [[0,1],[1,4]], fix first and disk and second as bulge.

        fix_Re_list:
            A list to define how to fix the Sersic effective radius, in the same
            [[component, value], ...] form as fix_n_list.

        apertures_center_focus:
            If True, the prior of the Sersic postion will be most limited to the center of the aperture
            (+/- 2 pixels instead of +/- 10 pixels).

    Return
    --------
        A Params list for the fitting:
        [kwargs_init, kwargs_sigma, kwargs_fixed, kwargs_lower, kwargs_upper].

    Raise
    --------
        ValueError: if a component is listed more than once in fix_n_list or fix_Re_list.
    """
    import lenstronomy.Util.param_util as param_util
    fixed_source = []
    kwargs_source_init = []
    kwargs_source_sigma = []
    kwargs_lower_source = []
    kwargs_upper_source = []
    center = int(frame_size/2)
    # Converted once, outside the loop; an empty list is treated like None.
    fix_n_table = _as_fix_table(fix_n_list)
    fix_Re_table = _as_fix_table(fix_Re_list)
    # Half-width (in pixels) of the allowed range for the Sersic centre.
    center_range = 2 if apertures_center_focus else 10
    for i, aper in enumerate(apertures):
        Reff = aper.a * deltaPix
        q = aper.b/aper.a
        # since data_configure_simple(inverse=True), aperture is anti-clock-wise, and
        # inverse=True means lenstronomy is clock-wise
        phi = - aper.theta
        e1, e2 = param_util.phi_q2_ellipticity(phi, q)

        # Accept both a flat (x, y) pair and a nested [(x, y)] for this aperture,
        # whatever the numeric type of the entries.
        position = np.atleast_2d(aper.positions)[0]
        pos_x, pos_y = position[0], position[1]
        c_x = -(pos_x - center) * deltaPix  # Lenstronomy defines x flipped, (i.e., East on the left.)
        c_y = (pos_y - center) * deltaPix

        fixed = {}
        init = {'R_sersic': Reff, 'n_sersic': 2., 'e1': e1, 'e2': e2, 'center_x': c_x, 'center_y': c_y}
        fix_n_value = _fixed_value_for(fix_n_table, i, 'n')
        if fix_n_value is not None:
            fixed['n_sersic'] = fix_n_value
            init['n_sersic'] = fix_n_value
        fix_Re_value = _fixed_value_for(fix_Re_table, i, 'Re')
        if fix_Re_value is not None:
            fixed['R_sersic'] = fix_Re_value
            init['R_sersic'] = fix_Re_value
        fixed_source.append(fixed)
        kwargs_source_init.append(init)

        kwargs_source_sigma.append({'n_sersic': 0.3, 'R_sersic': 0.2*deltaPix, 'e1': 0.1, 'e2': 0.1,
                                    'center_x': 0.1*deltaPix, 'center_y': 0.1*deltaPix})
        kwargs_lower_source.append({'e1': -0.5, 'e2': -0.5, 'R_sersic': deltaPix*0.05, 'n_sersic': 0.3,
                                    'center_x': c_x-center_range*deltaPix, 'center_y': c_y-center_range*deltaPix})
        kwargs_upper_source.append({'e1': 0.5, 'e2': 0.5, 'R_sersic': Reff*30, 'n_sersic': 9.,
                                    'center_x': c_x+center_range*deltaPix, 'center_y': c_y+center_range*deltaPix})
    source_params = [kwargs_source_init, kwargs_source_sigma, fixed_source, kwargs_lower_source, kwargs_upper_source]
    return source_params
def ps_params_generator(centers, deltaPix = 1):
    """
    Quickly generate a point source parameters for the fitting.

    Parameter
    --------
        centers: list.
            One [x, y] entry per point source; both values are multiplied by deltaPix.

        deltaPix:
            The pixel size of the data

    Return
    --------
        A Params list for the fitting:
        [kwargs_init, kwargs_sigma, kwargs_fixed, kwargs_lower, kwargs_upper],
        with a step size of 0.5 pixel and bounds of +/- 2 pixels around each center.
    """
    # Scale every center once; the five lists below are built from these pairs.
    scaled = [(entry[0] * deltaPix, entry[1] * deltaPix) for entry in centers]
    step = 0.5*deltaPix
    kwargs_ps_init = [{'ra_image': [ra], 'dec_image': [dec]} for ra, dec in scaled]
    kwargs_ps_sigma = [{'ra_image': [step], 'dec_image': [step]} for _ in scaled]
    fixed_ps = [{} for _ in scaled]
    kwargs_lower_ps = [{'ra_image': [ra-2*deltaPix], 'dec_image': [dec-2*deltaPix]} for ra, dec in scaled]
    kwargs_upper_ps = [{'ra_image': [ra+2*deltaPix], 'dec_image': [dec+2*deltaPix]} for ra, dec in scaled]
    return [kwargs_ps_init, kwargs_ps_sigma, fixed_ps, kwargs_lower_ps, kwargs_upper_ps]