# -*- coding: utf-8 -*-
##########################################################################
# pySAP - Copyright (C) CEA, 2017 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""DECONVOLVE.
This module defines functions to perform galaxy image deconvolution using the
Condat-Vu algorithm.
"""
# System import
import pysap
from pysap.plugins.astro.deconvolution.linear import WaveletConvolve2
from pysap.plugins.astro.deconvolution.wavelet_filters import get_cospy_filters
from pysap.utils import condatvu_logo
# Third party import
import numpy as np
from modopt.base.np_adjust import rotate
from modopt.opt.algorithms import Condat
from modopt.opt.cost import costObj
from modopt.opt.gradient import GradBasic
from modopt.opt.proximity import Positivity, SparseThreshold
from modopt.opt.reweight import cwbReweight
from modopt.math.convolve import convolve
from modopt.math.stats import sigma_mad
from modopt.signal.wavelet import filter_convolve
[docs]def psf_convolve(data, psf, psf_rot=False):
"""PSF Convolution.
Convolve the input data with the PSF provided.
Parameters
----------
data : numpy.ndarray
Input data, 2D image
psf : numpy.ndarray
Input PSF, 2D image
psf_rot : bool, optional
Option to rotate the input PSF, default is ``False``
Returns
-------
numpy.ndarray
Convolved image
"""
if psf_rot:
psf = rotate(psf)
return convolve(data, psf)
[docs]def get_weights(data, psf, filters, wave_thresh_factor=np.array([3, 3, 4])):
"""Get Sparsity Weights.
Get the weights needed for the sparse regularisation term in the
deconvolution problem.
Parameters
----------
data : numpy.ndarray
Input data, 2D image
psf : numpy.ndarray
Input PSF, 2D image
filters : numpy.ndarray
Wavelet filters
wave_thresh_factor : numpy.ndarray, optional
Threshold factors for each wavelet scale, default is
``np.array([3, 3, 4])``
Returns
-------
numpy.ndarray
Weights
"""
noise_est = sigma_mad(data)
filter_conv = filter_convolve(np.rot90(psf, 2), filters)
filter_norm = np.array([
np.linalg.norm(a) * b * np.ones(data.shape)
for a, b in zip(filter_conv, wave_thresh_factor)
])
return noise_est * filter_norm
[docs]def sparse_deconv_condatvu(
data,
psf,
n_iter=300,
n_reweights=1,
verbose=False,
progress=True,
):
"""Sparse Deconvolution with Condat-Vu.
Perform deconvolution using sparse regularisation with the Condat-Vu
algorithm.
Parameters
----------
data : numpy.ndarray
Input data, 2D image
psf : numpy.ndarray
Input PSF, 2D image
n_iter : int, optional
Maximum number of iterations, default is ``300``
n_reweights : int, optional
Number of reweightings, default is ``1``
verbose : bool, optional
Verbosity option, default is ``True``
progress : bool, optional
Option to show progress bar, default is ``True``
Returns
-------
numpy.ndarray
Deconvolved image
"""
# Print the algorithm set-up
if verbose:
print(condatvu_logo())
# Define the wavelet filters
filters = get_cospy_filters(
data.shape,
transform_name='LinearWaveletTransformATrousAlgorithm'
)
# Set the reweighting scheme
reweight = cwbReweight(get_weights(data, psf, filters))
# Set the initial variable values
primal = np.ones(data.shape)
dual = np.ones(filters.shape)
# Set the gradient operators
grad_op = GradBasic(
data,
lambda x: psf_convolve(x, psf),
lambda x: psf_convolve(x, psf, psf_rot=True),
)
# Set the linear operator
linear_op = WaveletConvolve2(filters)
# Set the proximity operators
prox_op = Positivity()
prox_dual_op = SparseThreshold(linear_op, reweight.weights)
# Set the cost function
cost_op = costObj(
[grad_op, prox_op, prox_dual_op],
tolerance=1e-6,
cost_interval=1,
plot_output=True,
verbose=verbose,
)
# Set the optimisation algorithm
alg = Condat(
primal,
dual,
grad_op,
prox_op,
prox_dual_op,
linear_op,
cost_op,
rho=0.8,
sigma=0.5,
tau=0.5,
auto_iterate=False,
progress=progress,
)
# Run the algorithm
alg.iterate(max_iter=n_iter)
# Implement reweigting
for rw_num in range(n_reweights):
reweight.reweight(linear_op.op(alg.x_final))
alg.iterate(max_iter=n_iter)
# Return the final result
return alg.x_final