# -*- coding: utf-8 -*-
##########################################################################
# pySAP - Copyright (C) CEA, 2017 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""This module defines primal dual optimizers."""
# System import
import time
# Package import
from .utils.reweight import mReweight
# Third party import
import numpy as np
from modopt.math.stats import sigma_mad
from modopt.opt.linear import Identity
from modopt.opt.algorithms import Condat
from modopt.opt.reweight import cwbReweight
def condatvu(gradient_op, linear_op, dual_regularizer, cost_op,
             kspace_generator=None, estimate_call_period=None,
             max_nb_of_iter=150, tau=None, sigma=None, relaxation_factor=1.0,
             x_init=None, std_est=None, std_est_method=None, std_thr=2.,
             nb_of_reweights=1, metric_call_period=5, metrics=None,
             verbose=0):
    """Condat-Vu sparse reconstruction with reweightings.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator. Its ``spec_rad`` is used as the Lipschitz
        constant and its ``fourier_op.shape`` sizes the default initial
        guess.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    dual_regularizer: instance of ProximityParent
        the dual regularization operator. Its ``weights`` attribute is
        overwritten when a reweighting strategy ('primal' or 'dual') is
        selected.
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    kspace_generator: instance of class BaseKspaceGenerator, default None
        If not None, run the algorithm in an online way, where the data is
        updated between iterations.
    estimate_call_period: int, default None
        In an online configuration (kspace_generator is defined),
        retrieve partial results at this interval.
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm. Only used when kspace_generator is None; in
        the online configuration the generator drives the iterations.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If sigma is None it is set to 0.5; if tau is None it is derived
        from the convergence bound 1/tau - sigma * ||L||^2 >= beta / 2.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    x_init: numpy.ndarray (optional, default None)
        the initial guess of image. If None, a squeezed complex64 zero
        array of shape (linear_op.n_coils, *gradient_op.fourier_op.shape)
        is used.
    std_est: float, default None
        the noise std estimate. If None and std_est_method is 'primal',
        the MAD of ``gradient_op.MtX(gradient_op.obs_data)`` is used; if
        None and std_est_method is 'dual', it starts at 0 and is
        re-estimated at each reweighting.
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain. If None, the current weights of
        dual_regularizer are kept and no reweighting is performed.
    std_thr: float, default 2.
        use this threshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    nb_of_reweights: int, default 1
        the number of reweightings. Forced to 0 when std_est_method is
        None.
    metric_call_period: int (default 5)
        the period on which the metrics are computed.
    metrics: dict (optional, default None)
        the list of desired convergence metrics: {'metric_name':
        [@metric, metric_parameter]}. See modopt for the metrics API.
        None means no metrics.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: numpy.ndarray
        the estimated CONDAT-VU solution.
    costs: list of float or None
        the cost function values, or None when the optimizer's cost object
        has no ``cost`` attribute.
    metrics: dict
        the requested metrics values during the optimization.
    y_final: numpy.ndarray
        the estimated dual CONDAT-VU solution

    Raises
    ------
    ValueError
        if std_est_method is not one of None, 'primal' or 'dual'.
    """
    # Check inputs
    start = time.perf_counter()
    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognize std estimation method '{0}'.".format(std_est_method))
    # Fresh dict per call instead of a shared mutable default argument.
    if metrics is None:
        metrics = {}

    # Define the initial primal and dual solutions
    if x_init is None:
        x_init = np.squeeze(np.zeros((linear_op.n_coils,
                                      *gradient_op.fourier_op.shape),
                                     dtype=np.complex64))
    primal = x_init
    dual = linear_op.op(primal)
    # The weights need their own buffer: they are filled in place below,
    # and aliasing them with `dual` would overwrite the initial dual
    # iterate given to the optimizer.
    weights = np.copy(dual)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator
    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(gradient_op.obs_data))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        dual_regularizer.weights = reweight_op.weights
    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        dual_regularizer.weights = reweight_op.weights
    # Case3: manual regularization mode, no reweighting
    else:
        reweight_op = None
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(x_init.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm ** 2 + eps)
    convergence_test = (
        1.0 / tau - sigma * norm ** 2 >= lipschitz_cst / 2.0)

    # Welcome message
    if verbose > 0:
        print(" - mu: ", dual_regularizer.weights)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    # The primal proximity operator is the identity: all the
    # regularization is carried by the dual proximity operator.
    prox_op = Identity()

    # Define the optimizer
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_op,
        prox_dual=dual_regularizer,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        sigma=sigma,
        tau=tau,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False,
        metric_call_period=metric_call_period,
        metrics=metrics)
    # Use the cost object actually held by the optimizer from here on.
    cost_op = opt._cost_func

    # Perform the first reconstruction
    if verbose > 0:
        print("Starting optimization...")
    if kspace_generator is None:
        opt.iterate(max_iter=max_nb_of_iter)
    else:
        kspace_generator.opt_iterate(
            opt, estimate_call_period=estimate_call_period)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):
        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)
        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)
        # Update the weights in the dual proximity operator
        dual_regularizer.weights = reweight_op.weights
        # Perform optimisation with new weights
        if kspace_generator is None:
            opt.iterate(max_iter=max_nb_of_iter)
        else:
            # Same call as the first online pass: the generator needs the
            # optimizer instance to drive it.
            kspace_generator.reset()
            kspace_generator.opt_iterate(
                opt, estimate_call_period=estimate_call_period)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        if hasattr(cost_op, "cost"):
            print(" - final iteration number: ", cost_op._iteration)
            print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt._x_new
    y_final = opt._y_new
    if hasattr(cost_op, "cost"):
        costs = cost_op._cost_list
    else:
        costs = None
    return x_final, costs, opt.metrics, y_final