Source code for mri.reconstructors.base

# -*- coding: utf-8 -*-
##########################################################################
# pySAP - Copyright (C) CEA, 2017 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

# System import
import warnings

# Third party import
from modopt.opt.linear import Identity
import numpy as np

# Package import
from ..optimizers import condatvu, fista, pogm
from ..optimizers.utils.cost import GenericCost
from ..generators.base import KspaceGeneratorBase

# Dispatch table: name accepted by ``optimization_alg`` -> optimizer callable.
# Insertion order is kept because the keys are listed in error messages.
OPTIMIZERS = dict(
    fista=fista,
    pogm=pogm,
    condatvu=condatvu,
)


class ReconstructorBase:
    """Base reconstructor class for reconstruction.

    This class holds some parameters that are common for all MR Image
    reconstructors.

    Notes
    -----
    For the Analysis case, finds the solution for x of:

    .. math::
        (1/2) * ||F x - y||^2_2 + mu * H (W x)

    For the Synthesis case, finds the solution of:

    .. math::
        (1/2) * ||F Wt alpha - y||^2_2 + mu * H(alpha)

    with

    .. math::
        alpha = W x and x = Wt alpha

    Parameters
    ----------
    fourier_op: instance of FourierOperatorBase
        Defines the fourier operator :math:`F` in the above equation
    linear_op: object
        Defines the linear sparsifying operator denoted :math:`W` in the
        above equation. This must operate on x and have 2 functions,
        op(x) and adj_op(coeff) which implements the operator and adjoint
        operator. For wavelets, this can be object of class WaveletN or
        WaveletUD2 from `mri.operators.linear`
    regularizer_op: operator or None
        Defines the regularization operator for the regularization
        function H. If None, the regularization chosen is Identity and the
        optimization turns to gradient descent. Defines :math:`H` in the
        above equation. This base class has no default for it.
    gradient_formulation: str between 'analysis' or 'synthesis'
        defines the formulation of the image model which defines the
        gradient. This base class has no default for it.
    grad_class: Gradient class from `mri.operators.gradient`.
        Points to the gradient class based on the MR Image model and
        gradient_formulation.
    init_gradient_op: bool, default True
        This parameter controls whether the gradient operator must be
        initialized right now. If set to false, the user needs to call
        initialize_gradient_op to initialize the gradient at right time
        before reconstruction
    verbose: int, optional default 0
        Verbosity levels
            1 => Print basic debug information
            5 => Print all initialization information
            20 => Calculate cost at the end of each iteration.
            30 => Print the debug information of operators if defined by
            class
            NOTE - High verbosity (>20) levels are computationally
            intensive.
    extra_grad_args: Extra Keyword arguments for gradient initialization
        This holds the initialization parameters used for gradient
        initialization which is obtained from 'grad_class'. Please refer
        to mri.operators.gradient.base for reference.
        In case of synthesis formulation, the 'linear_op' is also passed
        as an extra arg
    """

    def __init__(self, fourier_op, linear_op, regularizer_op,
                 gradient_formulation, grad_class, init_gradient_op=True,
                 verbose=0, **extra_grad_args):
        self.fourier_op = fourier_op
        self.linear_op = linear_op
        self.prox_op = regularizer_op
        self.gradient_method = gradient_formulation
        self.grad_class = grad_class
        self.verbose = verbose
        self.extra_grad_args = extra_grad_args
        if regularizer_op is None:
            warnings.warn("The prox_op is not set. Setting to identity. "
                          "Note that optimization is just a gradient descent.")
            self.prox_op = Identity()
        # TODO try to not use gradient_formulation and
        # rely on static attributes
        # If the reconstruction formulation is synthesis,
        # we send the linear operator as well.
        if gradient_formulation == 'synthesis':
            self.extra_grad_args['linear_op'] = self.linear_op
        if init_gradient_op:
            self.initialize_gradient_op(**self.extra_grad_args)

    def initialize_gradient_op(self, **extra_args):
        """Instantiate the gradient operator and store it on the instance.

        ``grad_class`` is called with this reconstructor's ``fourier_op``
        and ``verbose`` plus the given keyword arguments; the result is
        stored as ``self.gradient_op``.

        Parameters
        ----------
        extra_args: Extra keyword arguments forwarded to ``grad_class``.
        """
        # Initialize gradient operator and cost operators
        self.gradient_op = self.grad_class(
            fourier_op=self.fourier_op,
            verbose=self.verbose,
            **extra_args,
        )

    def reconstruct(self, kspace_data, optimization_alg='pogm', x_init=None,
                    num_iterations=100, cost_op_kwargs=None, **kwargs):
        """Perform the base reconstruction.

        Parameters
        ----------
        kspace_data: numpy.ndarray or KspaceGeneratorBase
            the acquired value in the Fourier domain.
            this is y in above equation.
        optimization_alg: str (optional, default 'pogm')
            Type of optimization algorithm to use, 'pogm' | 'fista' |
            'condatvu'
        x_init: numpy.ndarray (optional, default None)
            input initial guess image for reconstruction. If None, the
            initialization will be an ndarray of zeros
        num_iterations: int (optional, default 100)
            number of iterations of algorithm
        cost_op_kwargs: dict (optional, default None)
            specifies the extra keyword arguments for cost operations.
            please refer to modopt.opt.cost.costObj for details.
        kwargs: extra keyword arguments for modopt algorithm
            Please refer to corresponding ModOpt algorithm class for details.
            https://github.com/CEA-COSMIC/ModOpt/blob/master/modopt/opt/algorithms.py

        Returns
        -------
        tuple
            ``(x_final, costs, metrics)`` as produced by the selected
            optimizer. The same values are stored on the instance, along
            with ``cost_op`` and, for 'condatvu' only, ``y_final``.

        Raises
        ------
        ValueError
            If ``optimization_alg`` is not a key of ``OPTIMIZERS``.
        """
        # Resolve the optimizer first so that an invalid name is rejected
        # before any operator state (observed data, mask) is modified.
        try:
            optimizer = OPTIMIZERS[optimization_alg]
        except KeyError as e:
            raise ValueError("optimization_alg must be one of " +
                             str(list(OPTIMIZERS.keys()))) from e

        kspace_generator = None
        if isinstance(kspace_data, np.ndarray):
            self.gradient_op.obs_data = kspace_data
        elif isinstance(kspace_data, KspaceGeneratorBase):
            kspace_generator = kspace_data
            # The first item of the generator seeds the observed data and
            # the Fourier operator mask.
            (self.gradient_op._obs_data,
             self.fourier_op.mask) = kspace_generator[0]
        else:
            # Nothing is assigned in this case: make the no-op visible
            # instead of silently reconstructing from previous data.
            warnings.warn(
                "kspace_data is neither a numpy.ndarray nor a "
                "KspaceGeneratorBase; the observed data of the gradient "
                "operator was not updated."
            )

        # The regularizer is handed to the optimizer under a keyword that
        # depends on the optimizer family.
        if optimization_alg == "condatvu":
            kwargs["dual_regularizer"] = self.prox_op
            optimizer_type = 'primal_dual'
        else:
            kwargs["prox_op"] = self.prox_op
            optimizer_type = 'forward_backward'

        if cost_op_kwargs is None:
            cost_op_kwargs = {}
        self.cost_op = GenericCost(
            gradient_op=self.gradient_op,
            linear_op=self.linear_op,
            prox_op=self.prox_op,
            verbose=self.verbose >= 20,
            optimizer_type=optimizer_type,
            **cost_op_kwargs,
        )
        self.x_final, self.costs, *metrics = optimizer(
            kspace_generator=kspace_generator,
            gradient_op=self.gradient_op,
            linear_op=self.linear_op,
            cost_op=self.cost_op,
            max_nb_of_iter=num_iterations,
            x_init=x_init,
            verbose=self.verbose,
            **kwargs)
        # 'condatvu' yields two trailing values (metrics, y_final); the
        # other optimizers yield the metrics only.
        if optimization_alg == 'condatvu':
            self.metrics, self.y_final = metrics
        else:
            self.metrics = metrics[0]
        return self.x_final, self.costs, self.metrics