# -*- coding: utf-8 -*-
##########################################################################
# pySAP - Copyright (C) CEA, 2017 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
# System import
import warnings
# Third party import
from modopt.opt.linear import Identity
import numpy as np
# Package import
from ..optimizers import condatvu, fista, pogm
from ..optimizers.utils.cost import GenericCost
from ..generators.base import KspaceGeneratorBase
# Dispatch table: maps the ``optimization_alg`` name accepted by
# ``ReconstructorBase.reconstruct`` to the optimizer function to call.
OPTIMIZERS = {
    'fista': fista,
    'pogm': pogm,
    'condatvu': condatvu,
}
class ReconstructorBase:
    """Base reconstructor class for reconstruction.

    This class holds some parameters that are common for all MR Image
    reconstructors.

    Notes
    -----
    For the Analysis case, finds the solution for x of:

    .. math:: (1/2) * ||F x - y||^2_2 + mu * H (W x)

    For the Synthesis case, finds the solution of:

    .. math:: (1/2) * ||F Wt alpha - y||^2_2 + mu * H(alpha)

    with :math:`alpha = W x` and :math:`x = Wt alpha`.

    Parameters
    ----------
    fourier_op: instance of FourierOperatorBase
        Defines the fourier operator :math:`F` in the above equation.
    linear_op: object
        Defines the linear sparsifying operator denoted :math:`W` in the
        above equation. This must operate on x and have 2 functions,
        op(x) and adj_op(coeff) which implements the operator and adjoint
        operator. For wavelets, this can be object of class WaveletN or
        WaveletUD2 from `mri.operators.linear`.
    regularizer_op: operator or None
        Defines the regularization operator for the regularization
        function :math:`H` in the above equation. If None, a warning is
        emitted, the regularization chosen is Identity and the
        optimization turns to gradient descent.
    gradient_formulation: str, 'analysis' or 'synthesis'
        Defines the formulation of the image model which defines the
        gradient. When equal to 'synthesis', `linear_op` is added to the
        gradient initialization arguments.
    grad_class: Gradient class from `mri.operators.gradient`.
        Points to the gradient class based on the MR Image model and
        gradient_formulation.
    init_gradient_op: bool, default True
        This parameter controls whether the gradient operator must be
        initialized right now.
        If set to false, the user needs to call initialize_gradient_op to
        initialize the gradient at right time before reconstruction.
    verbose: int, optional default 0
        Verbosity levels
            1 => Print basic debug information
            5 => Print all initialization information
            20 => Calculate cost at the end of each iteration.
            30 => Print the debug information of operators if defined by
            class
            NOTE - High verbosity (>20) levels are computationally
            intensive.
    extra_grad_args: Extra Keyword arguments for gradient initialization
        This holds the initialization parameters used for gradient
        initialization which is obtained from 'grad_class'.
        Please refer to mri.operators.gradient.base for reference.
        In case of synthesis formulation, the 'linear_op' is also passed
        as an extra arg.
    """

    def __init__(self, fourier_op, linear_op, regularizer_op,
                 gradient_formulation, grad_class, init_gradient_op=True,
                 verbose=0, **extra_grad_args):
        self.fourier_op = fourier_op
        self.linear_op = linear_op
        self.prox_op = regularizer_op
        self.gradient_method = gradient_formulation
        self.grad_class = grad_class
        self.verbose = verbose
        self.extra_grad_args = extra_grad_args
        if regularizer_op is None:
            warnings.warn("The prox_op is not set. Setting to identity. "
                          "Note that optimization is just a gradient descent.")
            self.prox_op = Identity()
        # TODO try to not use gradient_formulation and
        # rely on static attributes
        # If the reconstruction formulation is synthesis,
        # we send the linear operator as well.
        if gradient_formulation == 'synthesis':
            self.extra_grad_args['linear_op'] = self.linear_op
        if init_gradient_op:
            self.initialize_gradient_op(**self.extra_grad_args)

    def initialize_gradient_op(self, **extra_args):
        """Instantiate the gradient operator and store it in `gradient_op`.

        Parameters
        ----------
        extra_args: extra keyword arguments
            Forwarded as-is to `grad_class`, together with `fourier_op`
            and `verbose`. Only the arguments given here are forwarded;
            `extra_grad_args` stored on the instance is not added
            implicitly.
        """
        self.gradient_op = self.grad_class(
            fourier_op=self.fourier_op,
            verbose=self.verbose,
            **extra_args,
        )

    def reconstruct(self, kspace_data, optimization_alg='pogm',
                    x_init=None, num_iterations=100, cost_op_kwargs=None,
                    **kwargs):
        """Perform the base reconstruction.

        Parameters
        ----------
        kspace_data: numpy.ndarray or KspaceGeneratorBase
            the acquired value in the Fourier domain.
            this is y in above equation.
            Any other non-None value is not used and triggers a warning.
        optimization_alg: str (optional, default 'pogm')
            Type of optimization algorithm to use, 'pogm' | 'fista' |
            'condatvu'
        x_init: numpy.ndarray (optional, default None)
            input initial guess image for reconstruction. If None, the
            initialization will be an ndarray of zeros
        num_iterations: int (optional, default 100)
            number of iterations of algorithm
        cost_op_kwargs: dict (optional, default None)
            specifies the extra keyword arguments for cost operations.
            please refer to modopt.opt.cost.costObj for details.
        kwargs: extra keyword arguments for modopt algorithm
            Please refer to corresponding ModOpt algorithm class for
            details.
            https://github.com/CEA-COSMIC/ModOpt/blob/master/modopt/opt/algorithms.py

        Returns
        -------
        tuple
            ``(x_final, costs, metrics)`` as returned by the optimizer.
            These are also stored on the instance; for 'condatvu' the
            extra value returned by the optimizer is stored as
            ``y_final``.

        Raises
        ------
        AttributeError
            If the gradient operator has not been initialized yet.
        ValueError
            If `optimization_alg` is not a key of ``OPTIMIZERS``.
        """
        # Fail early with an actionable message instead of an opaque
        # attribute lookup failure further down.
        if not hasattr(self, "gradient_op"):
            raise AttributeError(
                "The gradient operator is not initialized. Call "
                "initialize_gradient_op before reconstruct."
            )
        # Validate the algorithm name before touching any operator state so
        # that a bad name leaves the reconstructor untouched.
        try:
            optimizer = OPTIMIZERS[optimization_alg]
        except KeyError as e:
            raise ValueError("optimization_alg must be one of "
                             + str(list(OPTIMIZERS.keys()))) from e
        kspace_generator = None
        if isinstance(kspace_data, np.ndarray):
            self.gradient_op.obs_data = kspace_data
        elif isinstance(kspace_data, KspaceGeneratorBase):
            kspace_generator = kspace_data
            # The first item of the generator provides the initial observed
            # data and the matching mask of the fourier operator.
            self.gradient_op._obs_data, self.fourier_op.mask = \
                kspace_generator[0]
        elif kspace_data is not None:
            # Do not silently reconstruct from whatever data the gradient
            # operator currently holds.
            warnings.warn(
                "kspace_data of type " + type(kspace_data).__name__
                + " is neither a numpy.ndarray nor a KspaceGeneratorBase;"
                " it is ignored and the observed data of the gradient"
                " operator is left unchanged."
            )
        # The proximity operator is handed to the optimizer under a keyword
        # that depends on the optimizer family.
        if optimization_alg == "condatvu":
            kwargs["dual_regularizer"] = self.prox_op
            optimizer_type = 'primal_dual'
        else:
            kwargs["prox_op"] = self.prox_op
            optimizer_type = 'forward_backward'
        if cost_op_kwargs is None:
            cost_op_kwargs = {}
        self.cost_op = GenericCost(
            gradient_op=self.gradient_op,
            linear_op=self.linear_op,
            prox_op=self.prox_op,
            verbose=self.verbose >= 20,
            optimizer_type=optimizer_type,
            **cost_op_kwargs,
        )
        self.x_final, self.costs, *metrics = optimizer(
            kspace_generator=kspace_generator,
            gradient_op=self.gradient_op,
            linear_op=self.linear_op,
            cost_op=self.cost_op,
            max_nb_of_iter=num_iterations,
            x_init=x_init,
            verbose=self.verbose,
            **kwargs)
        if optimization_alg == 'condatvu':
            # condatvu returns one more value after the metrics.
            self.metrics, self.y_final = metrics
        else:
            self.metrics = metrics[0]
        return self.x_final, self.costs, self.metrics