# -*- coding: utf-8 -*-########################################################################### pySAP - Copyright (C) CEA, 2017 - 2018# Distributed under the terms of the CeCILL-B license, as published by# the CEA-CNRS-INRIA. Refer to the LICENSE file or to# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html# for details.##########################################################################"""DECONVOLVE.This module defines functions to perform galaxy image deconvolution using theCondat-Vu algorithm."""# System importimportpysapfrompysap.plugins.astro.deconvolution.linearimportWaveletConvolve2frompysap.plugins.astro.deconvolution.wavelet_filtersimportget_cospy_filtersfrompysap.utilsimportcondatvu_logo# Third party importimportnumpyasnpfrommodopt.base.np_adjustimportrotatefrommodopt.opt.algorithmsimportCondatfrommodopt.opt.costimportcostObjfrommodopt.opt.gradientimportGradBasicfrommodopt.opt.proximityimportPositivity,SparseThresholdfrommodopt.opt.reweightimportcwbReweightfrommodopt.math.convolveimportconvolvefrommodopt.math.statsimportsigma_madfrommodopt.signal.waveletimportfilter_convolve
def psf_convolve(data, psf, psf_rot=False):
    """PSF Convolution.

    Convolve the input data with the PSF provided, optionally rotating the
    PSF first.

    Parameters
    ----------
    data : numpy.ndarray
        Input data, 2D image
    psf : numpy.ndarray
        Input PSF, 2D image
    psf_rot : bool, optional
        If ``True``, the PSF is passed through ``rotate`` before the
        convolution, default is ``False``

    Returns
    -------
    numpy.ndarray
        Convolved image

    """
    # Pick the kernel once, then do a single convolution call.
    kernel = rotate(psf) if psf_rot else psf
    return convolve(data, kernel)
def get_weights(data, psf, filters, wave_thresh_factor=None):
    """Get Sparsity Weights.

    Get the weights needed for the sparse regularisation term in the
    deconvolution problem. Each weight plane is the noise estimate of
    ``data`` times the norm of the 180-degree rotated PSF convolved with one
    wavelet filter, times that scale's threshold factor.

    Parameters
    ----------
    data : numpy.ndarray
        Input data, 2D image
    psf : numpy.ndarray
        Input PSF, 2D image
    filters : numpy.ndarray
        Wavelet filters
    wave_thresh_factor : array_like, optional
        Threshold factors for each wavelet scale. When ``None`` (the
        default), ``np.array([3, 3, 4])`` is used

    Returns
    -------
    numpy.ndarray
        Weights, stacked along the first axis with one ``data.shape`` array
        per filtered scale

    Raises
    ------
    ValueError
        If fewer threshold factors are supplied than there are filtered
        scales

    """
    # Use a None sentinel rather than an array default: default values are
    # evaluated once and shared across calls.
    if wave_thresh_factor is None:
        wave_thresh_factor = np.array([3, 3, 4])
    wave_thresh_factor = np.asarray(wave_thresh_factor)

    noise_est = sigma_mad(data)
    # np.rot90(psf, 2) rotates the PSF by 180 degrees before filtering.
    filter_conv = filter_convolve(np.rot90(psf, 2), filters)

    # zip() would silently drop scales that have no threshold factor,
    # returning fewer weight planes than wavelet scales.
    if len(wave_thresh_factor) < len(filter_conv):
        raise ValueError(
            'wave_thresh_factor has {0} entries but {1} wavelet scales '
            'are present'.format(len(wave_thresh_factor), len(filter_conv))
        )

    filter_norm = np.array([
        np.linalg.norm(scale) * factor * np.ones(data.shape)
        for scale, factor in zip(filter_conv, wave_thresh_factor)
    ])

    return noise_est * filter_norm
def sparse_deconv_condatvu(
    data,
    psf,
    n_iter=300,
    n_reweights=1,
    verbose=False,
    progress=True,
):
    """Sparse Deconvolution with Condat-Vu.

    Perform deconvolution using sparse regularisation with the Condat-Vu
    algorithm, followed by ``n_reweights`` rounds of reweighting.

    Parameters
    ----------
    data : numpy.ndarray
        Input data, 2D image
    psf : numpy.ndarray
        Input PSF, 2D image
    n_iter : int, optional
        Maximum number of iterations, default is ``300``. The same limit is
        applied to the initial run and to each reweighting run
    n_reweights : int, optional
        Number of reweightings, default is ``1``
    verbose : bool, optional
        Verbosity option, default is ``False``
    progress : bool, optional
        Option to show progress bar, default is ``True``

    Returns
    -------
    numpy.ndarray
        Deconvolved image

    """
    # Print the algorithm set-up
    if verbose:
        print(condatvu_logo())

    # Define the wavelet filters
    filters = get_cospy_filters(
        data.shape,
        transform_name='LinearWaveletTransformATrousAlgorithm',
    )

    # Set the reweighting scheme
    reweight = cwbReweight(get_weights(data, psf, filters))

    # Set the initial variable values
    primal = np.ones(data.shape)
    dual = np.ones(filters.shape)

    # Set the gradient operators: the forward operator convolves with the
    # PSF, the adjoint convolves with the rotated PSF (psf_rot=True).
    grad_op = GradBasic(
        data,
        lambda x: psf_convolve(x, psf),
        lambda x: psf_convolve(x, psf, psf_rot=True),
    )

    # Set the linear operator
    linear_op = WaveletConvolve2(filters)

    # Set the proximity operators.
    # NOTE(review): the reweighting loop below relies on prox_dual_op seeing
    # the updated reweight.weights (presumably modified in place by
    # cwbReweight.reweight) — confirm against the modopt version in use.
    prox_op = Positivity()
    prox_dual_op = SparseThreshold(linear_op, reweight.weights)

    # Set the cost function.
    # NOTE(review): modopt documents plot_output as an output file name;
    # confirm that passing ``True`` is intended.
    cost_op = costObj(
        [grad_op, prox_op, prox_dual_op],
        tolerance=1e-6,
        cost_interval=1,
        plot_output=True,
        verbose=verbose,
    )

    # Set the optimisation algorithm
    alg = Condat(
        primal,
        dual,
        grad_op,
        prox_op,
        prox_dual_op,
        linear_op,
        cost_op,
        rho=0.8,
        sigma=0.5,
        tau=0.5,
        auto_iterate=False,
        progress=progress,
    )

    # Run the algorithm
    alg.iterate(max_iter=n_iter)

    # Implement reweighting: update the weights from the wavelet coefficients
    # of the current estimate, then continue iterating.
    for _ in range(n_reweights):
        reweight.reweight(linear_op.op(alg.x_final))
        alg.iterate(max_iter=n_iter)

    # Return the final result
    return alg.x_final