# Source code for mri.optimizers.utils.reweight
# -*- coding: utf-8 -*-
##########################################################################
# pySAP - Copyright (C) CEA, 2017 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Reweighting optimisation strategies.
"""
# Package import
from pysap.base.utils import flatten
# Third party import
import numpy as np
from modopt.math.stats import sigma_mad
class mReweight(object):
    """Ming reweighting.

    This class implements the reweighting scheme described in Ming2017.

    Parameters
    ----------
    weights: numpy.ndarray
        Array of weights.
    linear_op: pysap.numeric.linear.Wavelet
        A linear operator.
    thresh_factor: float, default 1
        Threshold factor: sigma threshold.
    """

    def __init__(self, weights, linear_op, thresh_factor=1):
        self.weights = weights
        # Independent snapshot of the weights as supplied; ``reweight``
        # replaces ``self.weights`` but never touches this copy.
        self.original_weights = np.copy(self.weights)
        self.thresh_factor = thresh_factor
        self.linear_op = linear_op

    def reweight(self, x_new):
        """Update the weights.

        For every scale of ``linear_op.transform`` a noise level is estimated
        with ``sigma_mad`` on the flattened bands of that scale, and every
        coefficient of the scale receives the weight
        ``thresh_factor * sigma``.  The last scale always gets a level of
        ``0.`` (hence zero weights).  ``self.weights`` is replaced by the
        concatenation of the per-scale weights, in scale order.

        Parameters
        ----------
        x_new: numpy.ndarray
            the current primal solution.

        Returns
        -------
        sigma_est: list
            the ``sigma_mad`` estimate on each scale (``0.`` for the last
            scale).
        """
        # The return value is deliberately ignored: the call is made for its
        # effect on ``linear_op.transform``, which is indexed per scale below
        # (presumably it refreshes the stored coefficients -- verify against
        # the linear operator implementation).
        self.linear_op.op(x_new)

        transform = self.linear_op.transform
        nb_scale = transform.nb_scale
        last_scale = nb_scale - 1
        dtype = self.weights.dtype

        # Collect the per-scale weights and concatenate once at the end
        # instead of re-copying the growing array on every iteration.  The
        # empty seed keeps the zero-scale case valid and fixes the dtype.
        chunks = [np.empty((0, ), dtype=dtype)]
        sigma_est = []
        for scale in range(nb_scale):
            bands_array, _ = flatten(transform[scale])
            if scale == last_scale:
                # The last scale is never thresholded.
                std_at_scale_i = 0.
            else:
                std_at_scale_i = sigma_mad(bands_array)
            sigma_est.append(std_at_scale_i)
            thr = np.ones(bands_array.shape, dtype=dtype)
            thr *= self.thresh_factor * std_at_scale_i
            chunks.append(thr)

        self.weights = np.concatenate(chunks)
        return sigma_est