# -*- coding: utf-8 -*-
##########################################################################
# pySAP - Copyright (C) CEA, 2017 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""FISTA or POGM MRI reconstruction."""

# Third party import
import numpy as np
from modopt.opt.algorithms import ForwardBackward, POGM

# Package import
from .base import run_algorithm, run_online_algorithm


def fista(gradient_op, linear_op, prox_op, cost_op, kspace_generator=None,
          estimate_call_period=None, lambda_init=1.0, max_nb_of_iter=300,
          x_init=None, metric_call_period=5, metrics={}, verbose=0,
          **lambda_update_params):
    """FISTA sparse reconstruction.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator that seeks the sparsity, e.g. a wavelet
        transform.
    prox_op: instance of ProximityParent
        the proximal operator.
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    kspace_generator: instance of class BaseKspaceGenerator, default None
        if not None, run the algorithm in an online way, where the data is
        updated between iterations.
    estimate_call_period: int, default None
        in an online configuration (kspace_generator is defined), retrieve
        partial results at this interval.
    lambda_init: float (optional, default 1.0)
        initial value for the FISTA step.
    max_nb_of_iter: int (optional, default 300)
        the maximum number of iterations in the FISTA algorithm.
    x_init: numpy.ndarray (optional, default None)
        initial guess for the image.
    metric_call_period: int (optional, default 5)
        the period on which the metrics are computed.
    metrics: dict (optional, default {})
        the list of desired convergence metrics:
        {'metric_name': [@metric, metric_parameter]}.
        See modopt for the metrics API.
    verbose: int (optional, default 0)
        the verbosity level.
    lambda_update_params: dict
        parameters for the lambda update in FISTA mode.

    Returns
    -------
    x_final: numpy.ndarray
        the estimated FISTA solution.
    costs: list of float
        the cost function values.
    metrics: dict
        the requested metrics values during the optimization.
    """
    # Define the initial primal and dual solutions
    if x_init is None:
        x_init = np.squeeze(np.zeros(
            (gradient_op.linear_op.n_coils, *gradient_op.fourier_op.shape),
            dtype=np.complex64))
    alpha_init = linear_op.op(x_init)

    # Welcome message
    if verbose > 0:
        print(" - mu: ", prox_op.weights)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - image variable shape: ", gradient_op.fourier_op.shape)
        print(" - alpha variable shape: ", alpha_init.shape)
        print("-" * 40)

    beta_param = gradient_op.inv_spec_rad
    if lambda_update_params.get("restart_strategy") == "greedy":
        lambda_update_params["min_beta"] = gradient_op.inv_spec_rad
        # this value is the one recommended by J. Liang in the paper
        # introducing greedy FISTA.
        # ref: https://arxiv.org/pdf/1807.04005.pdf
        beta_param *= 1.3

    # Define the optimizer
    opt = ForwardBackward(
        x=alpha_init,
        grad=gradient_op,
        prox=prox_op,
        cost=cost_op,
        auto_iterate=False,
        metric_call_period=metric_call_period,
        metrics=metrics,
        linear=linear_op,
        lambda_param=lambda_init,
        beta_param=beta_param,
        **lambda_update_params)

    if kspace_generator is not None:
        return run_online_algorithm(
            opt, kspace_generator, estimate_call_period, verbose)
    return run_algorithm(opt, max_nb_of_iter, verbose)
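

# Usage sketch for ``fista``, kept as a comment so that importing this module
# stays side-effect free. The operator classes, constructor arguments and the
# ``kspace_loc`` / ``kspace_data`` variables are assumptions about a typical
# pysap-mri pipeline, not guarantees about its API; adapt them to your setup:
#
#     from mri.operators import FFT, WaveletN, GradSynthesis
#     from modopt.opt.proximity import SparseThreshold
#     from modopt.opt.cost import costObj
#
#     fourier_op = FFT(samples=kspace_loc, shape=(256, 256))
#     linear_op = WaveletN(wavelet_name="sym8", nb_scale=4)
#     gradient_op = GradSynthesis(linear_op=linear_op, fourier_op=fourier_op)
#     gradient_op.obs_data = kspace_data
#     prox_op = SparseThreshold(linear_op, weights=1e-7, thresh_type="soft")
#     cost_op = costObj([gradient_op, prox_op], verbose=False)
#
#     # Plain FISTA run; a greedy-restart run would additionally pass, e.g.,
#     # restart_strategy="greedy", xi_restart=0.96, s_greedy=1.1.
#     x_final, costs, metrics = fista(
#         gradient_op, linear_op, prox_op, cost_op, max_nb_of_iter=200)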


def pogm(gradient_op, linear_op, prox_op, cost_op=None, kspace_generator=None,
         estimate_call_period=None, max_nb_of_iter=300, x_init=None,
         metric_call_period=5, sigma_bar=0.96, metrics={}, verbose=0):
    """Perform sparse reconstruction using the POGM algorithm.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator that seeks the sparsity, e.g. a wavelet
        transform.
    prox_op: instance of ProximityParent
        the proximal operator.
    cost_op: instance of costObj (optional, default None)
        the cost function used to check for convergence during the
        optimization.
    kspace_generator: instance of BaseKspaceGenerator, default None
        if not None, use it to perform an online reconstruction.
    estimate_call_period: int, default None
        in an online configuration (kspace_generator is defined), retrieve
        partial results at this interval.
    max_nb_of_iter: int (optional, default 300)
        the maximum number of iterations in the POGM algorithm.
    x_init: numpy.ndarray (optional, default None)
        initial guess for the image.
    metric_call_period: int (optional, default 5)
        the period on which the metrics are computed.
    sigma_bar: float (optional, default 0.96)
        the value of the POGM shrinking parameter; must lie in [0, 1].
    metrics: dict (optional, default {})
        the list of desired convergence metrics:
        {'metric_name': [@metric, metric_parameter]}.
        See modopt for the metrics API.
    verbose: int (optional, default 0)
        the verbosity level.

    Returns
    -------
    x_final: numpy.ndarray
        the estimated POGM solution.
    costs: list of float
        the cost function values.
    metrics: dict
        the requested metrics values during the optimization.
    """
    # Define the initial values
    im_shape = (gradient_op.linear_op.n_coils, *gradient_op.fourier_op.shape)
    if x_init is None:
        alpha_init = linear_op.op(
            np.squeeze(np.zeros(im_shape, dtype=np.complex64)))
    else:
        alpha_init = linear_op.op(x_init)

    # Welcome message
    if verbose > 0:
        print(" - mu: ", prox_op.weights)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - image variable shape: ", im_shape)
        print("-" * 40)

    # Hyper-parameters
    beta = gradient_op.inv_spec_rad

    opt = POGM(
        u=alpha_init,
        x=alpha_init,
        y=alpha_init,
        z=alpha_init,
        grad=gradient_op,
        prox=prox_op,
        cost=cost_op,
        linear=linear_op,
        beta_param=beta,
        sigma_bar=sigma_bar,
        metric_call_period=metric_call_period,
        metrics=metrics,
        auto_iterate=False,
    )
    if kspace_generator is not None:
        return run_online_algorithm(
            opt, kspace_generator, estimate_call_period, verbose)
    return run_algorithm(opt, max_nb_of_iter, verbose=verbose)
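

# A matching sketch for ``pogm``: it consumes the same operator stack as
# ``fista`` (the hypothetical objects from the sketch above). If ``cost_op``
# is left at its default of None, no convergence check is performed and the
# loop runs for the full ``max_nb_of_iter`` iterations:
#
#     x_final, costs, metrics = pogm(
#         gradient_op, linear_op, prox_op, cost_op=cost_op,
#         max_nb_of_iter=200, sigma_bar=0.96)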