import numpy as np
from modopt.opt.proximity import SparseThreshold
from modopt.opt.linear import Identity
class WeightedSparseThreshold(SparseThreshold):
    """Weighted version of `SparseThreshold` in ModOpt.

    When `scale_based` is chosen, the user provides an array of weights
    W[i], and each weight is assigned to the respective scale `i`.
    Alternatively, custom per-coefficient weights can be defined.
    Note that the weight on the coarse scale is always set to 0.

Parameters
----------
    weights : numpy.ndarray
        Input array of weights, or a tuple holding a base weight W and a power P.
    coeffs_shape : tuple
        The shape of the linear coefficients.
    weight_type : str, {'custom', 'scale_based'}, default 'scale_based'
        Mode of operation of the proximity:

        - custom -> custom array of weights
        - scale_based -> custom weights applied per scale
    zero_weight_coarse : bool, default True
        If True, the weights of the coarse scale are set to 0.
    linear : object, default `Identity()`
        Linear operator, to be used in cost function evaluation.

See Also
--------
SparseThreshold : parent class
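
    Examples
    --------
    A minimal sketch (the subband shapes below are illustrative, not taken
    from a real transform):

    >>> import numpy as np
    >>> coeffs_shape = [(64,), (64,), (128,)]  # coarse band + two detail bands
    >>> prox = WeightedSparseThreshold(
    ...     weights=np.array([1e-6, 1e-5]),  # one weight per distinct scale
    ...     coeffs_shape=coeffs_shape,
    ...     weight_type='scale_based',
    ... )
    >>> bool(np.all(prox.mu[:64] == 0))  # coarse scale is never thresholded
    True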
"""
def __init__(self, weights, coeffs_shape, weight_type='scale_based',
zero_weight_coarse=True, linear=Identity(), **kwargs):
self.cf_shape = coeffs_shape
self.weight_type = weight_type
available_weight_type = ('scale_based', 'custom')
if self.weight_type not in available_weight_type:
raise ValueError('Weight type must be one of ' +
' '.join(available_weight_type))
self.zero_weight_coarse = zero_weight_coarse
self.mu = weights
        super().__init__(
weights=self.mu,
linear=linear,
**kwargs
)
@property
def mu(self):
"""`mu` is the weights used for thresholding"""
return self.weights
@mu.setter
def mu(self, w):
"""Update `mu`, based on `coeffs_shape` and `weight_type`"""
weights_init = np.zeros(np.sum(np.prod(self.cf_shape, axis=-1)))
start = 0
        if self.weight_type == 'scale_based':
            scale_shapes = np.unique(self.cf_shape, axis=0)
            num_scales = len(scale_shapes)
            if isinstance(w, (float, int, np.float64)):
                weights = w * np.ones(num_scales)
            else:
                if len(w) != num_scales:
                    raise ValueError('The number of weights does not match '
                                     'the number of scales')
                weights = w
            for i, scale_shape in enumerate(scale_shapes):
                scale_sz = np.prod(scale_shape)
                # Each distinct shape may correspond to several subbands.
                stop = start + scale_sz * np.sum(scale_shape == self.cf_shape)
                weights_init[start:stop] = weights[i]
                start = stop
elif self.weight_type == 'custom':
if isinstance(w, (float, int, np.float64)):
w = w * np.ones(weights_init.shape[0])
weights_init = w
if self.zero_weight_coarse:
weights_init[:np.prod(self.cf_shape[0])] = 0
self.weights = weights_init
def _sigma_mad(data, centered=True):
    r"""Return a robust estimate of the standard deviation.

    The standard deviation is computed using the following estimator, based
    on the median absolute deviation (MAD) of the data [#]_:

    .. math::
        \hat{\sigma} = \frac{MAD}{\sqrt{2}\,\textrm{erf}^{-1}(1/2)}

Parameters
----------
data: numpy.ndarray
the data on which the standard deviation will be estimated.
    centered: bool, default True
        If True, the median of the data is assumed to be 0.

Returns
-------
float:
The estimation of the standard deviation.
References
----------
.. [#] https://en.m.wikipedia.org/wiki/Median_absolute_deviation
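
    Examples
    --------
    An illustrative sketch on synthetic Gaussian noise:

    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> sigma = _sigma_mad(2.0 * rng.standard_normal(10000))
    >>> bool(0.5 < sigma / 2.0 < 1.5)  # close to the true value of 2
    True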
"""
    # 0.6745 is the expected MAD of a standard Gaussian distribution,
    # i.e. sqrt(2) * erfinv(1/2).
    if centered:
        return np.median(np.abs(data[:])) / 0.6745
    return np.median(np.abs(data[:] - np.median(data[:]))) / 0.6745
def _sure_est(data):
    """Return an estimate of the threshold computed using the SURE method.

    The computation of the estimator is based on the formulation of
    :cite:`donoho1994` and the efficient implementation of [#]_.

Parameters
----------
    data: numpy.ndarray
        Noisy data with unit standard deviation.

Returns
-------
float
Value of the threshold minimizing the SURE estimator.
References
----------
.. [#] https://pyyawt.readthedocs.io/_modules/pyyawt/denoising.html#ValSUREThresh
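
    Examples
    --------
    Illustrative only; the selected threshold depends on the noise
    realization:

    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> thr = _sure_est(rng.standard_normal(1024))
    >>> bool(thr > 0)
    True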
"""
    dataf = data.flatten()
    n = dataf.size
    data_sorted = np.sort(np.abs(dataf)) ** 2
    idx = np.arange(n - 1, -1, -1)
    # SURE risk of thresholding at each sorted coefficient magnitude.
    tmp = np.cumsum(data_sorted) + idx * data_sorted
    risk = (n - (2 * np.arange(n)) + tmp) / n
    # The best threshold is the one minimizing the risk.
    ibest = np.argmin(risk)
    return np.sqrt(data_sorted[ibest])
def _thresh_select(data, thresh_est):
    """Select a threshold for denoising, following :cite:`donoho1994`.

    Parameters
    ----------
    data: numpy.ndarray
        Noisy data on which a threshold will be estimated. It should only be
        corrupted by standard Gaussian white noise N(0,1).
    thresh_est: str
        Threshold estimation method. Available are "sure", "universal",
        and "hybrid-sure".

Returns
-------
float:
the threshold for the data provided.
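
    Examples
    --------
    The universal threshold is deterministic given the sample size:

    >>> import numpy as np
    >>> noise = np.random.default_rng(0).standard_normal(512)
    >>> bool(np.isclose(_thresh_select(noise, "universal"),
    ...                 np.sqrt(2 * np.log(512))))
    True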
"""
n = data.size
universal_thr = np.sqrt(2*np.log(n))
if thresh_est == "sure":
thr = _sure_est(data)
elif thresh_est == "universal":
thr = universal_thr
elif thresh_est == "hybrid-sure":
eta = np.sum(data ** 2) /n - 1
if eta < (np.log2(n) ** 1.5) / np.sqrt(n):
thr = universal_thr
else:
test_th = _sure_est(data)
thr = min(test_th, universal_thr)
    else:
        raise ValueError(
            "Unsupported threshold method. "
            "Available are 'sure', 'universal' and 'hybrid-sure'."
        )
return thr
def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est):
    r"""Return an estimate of the noise standard deviation in each subband.

    Parameters
    ----------
    wavelet_coeffs: numpy.ndarray
        Flattened array of wavelet coefficients, typically returned by
        ``WaveletN.op``.
    coeffs_shape: list of tuple
        List of tuples representing the shape of each subband.
        Typically accessible as ``WaveletN.coeffs_shape``.
    sigma_est: str
        Estimation method; available are "band", "scale", and "global".
        If None, a unit standard deviation is assumed for every subband.

    Returns
    -------
    numpy.ndarray
        Estimate of the noise standard deviation for each wavelet subband.

    Notes
    -----
    This method makes several assumptions:

    - The wavelet coefficients are ordered by scale, and the scales are
      ordered by size.
    - At each scale, the subbands should have the same shape.

    The standard deviation estimation is either performed:

    - On each subband (``sigma_est = "band"``)
    - On each scale, using the detailed HH subband (``sigma_est = "scale"``)
    - Only on the largest, most detailed HH subband (``sigma_est = "global"``)

See Also
--------
_sigma_mad: function estimating the standard deviation.
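
    Examples
    --------
    A sketch with a fabricated subband layout (coarse band followed by
    three bands per scale, as in a 2D decimated transform):

    >>> import numpy as np
    >>> coeffs_shape = [(8, 8)] + [(8, 8)] * 3 + [(16, 16)] * 3
    >>> n_coeffs = int(sum(np.prod(s) for s in coeffs_shape))
    >>> coeffs = np.random.default_rng(0).standard_normal(n_coeffs)
    >>> sigmas = wavelet_noise_estimate(coeffs, coeffs_shape, "scale")
    >>> len(sigmas) == len(coeffs_shape)
    True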
"""
    sigma_ret = np.ones(len(coeffs_shape))
    sigma_ret[0] = np.nan
    start = 0
    stop = 0
if sigma_est is None:
return sigma_ret
elif sigma_est == "band":
for i in range(1, len(coeffs_shape)):
stop += np.prod(coeffs_shape[i])
sigma_ret[i] = _sigma_mad(wavelet_coeffs[start:stop])
start = stop
elif sigma_est == "scale":
# use the diagonal coefficients subband to estimate the variance of the scale.
# it assumes that the band of the same scale have the same shape.
start = np.prod(coeffs_shape[0])
for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)):
scale_sz = np.prod(scale_shape)
matched_bands = np.all(scale_shape == coeffs_shape[1:], axis=1)
bpl = np.sum(matched_bands)
start = start + scale_sz * (bpl-1)
stop = start + scale_sz * bpl
sigma_ret[1+i*(bpl):1+(i+1)*bpl] = _sigma_mad(wavelet_coeffs[start:stop])
start = stop
elif sigma_est == "global":
sigma_ret *= _sigma_mad(wavelet_coeffs[-np.prod(coeffs_shape[-1]):])
sigma_ret[0] = np.NaN
return sigma_ret
def wavelet_threshold_estimate(
wavelet_coeffs,
coeffs_shape,
thresh_range="global",
sigma_range="global",
thresh_estimation="hybrid-sure"
):
"""Estimate wavelet coefficient thresholds.
Notes that no threshold will be estimate for the coarse scale.
Parameters
----------
    wavelet_coeffs: numpy.ndarray
        Flattened array of wavelet coefficients, typically returned by
        ``WaveletN.op``.
    coeffs_shape: list
        List of tuples representing the shape of each subband.
        Typically accessible as ``WaveletN.coeffs_shape``.
    thresh_range: str, default "global"
        Defines on which data range to estimate thresholds.
        Either "band", "scale", or "global".
    sigma_range: str, default "global"
        Defines on which data range to estimate the noise standard deviation.
        Either "band", "scale", or "global".
    thresh_estimation: str, default "hybrid-sure"
        Name of the threshold estimation method.
        Available are "sure", "hybrid-sure", "universal".

Returns
-------
numpy.ndarray
        Array of thresholds, one per wavelet coefficient.
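
    Examples
    --------
    A sketch reusing the fabricated band layout from
    :func:`wavelet_noise_estimate`:

    >>> import numpy as np
    >>> coeffs_shape = [(8, 8)] + [(8, 8)] * 3 + [(16, 16)] * 3
    >>> n_coeffs = int(sum(np.prod(s) for s in coeffs_shape))
    >>> coeffs = np.random.default_rng(0).standard_normal(n_coeffs)
    >>> weights = wavelet_threshold_estimate(coeffs, coeffs_shape)
    >>> bool(np.all(weights[:64] == 0))  # coarse scale is not thresholded
    True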
"""
    weights = np.ones(wavelet_coeffs.shape)
    weights[:np.prod(coeffs_shape[0])] = 0
    # Estimate the noise std on the specified range.
    sigma_bands = wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_range)
    # Compute the threshold on each specified range.
    start = np.prod(coeffs_shape[0])
    stop = start
    if thresh_range == "global":
        # A single threshold, estimated on the finest (last) HH subband,
        # applied to every detail coefficient.
        weights[start:] = sigma_bands[-1] * _thresh_select(
            wavelet_coeffs[-np.prod(coeffs_shape[-1]):] / sigma_bands[-1],
            thresh_estimation
        )
elif thresh_range == "band":
for i in range(1, len(coeffs_shape)):
stop = start + np.prod(coeffs_shape[i])
t = sigma_bands[i] * _thresh_select(
wavelet_coeffs[start:stop] / sigma_bands[i],
thresh_estimation
)
ts.append(t)
weights[start:stop] = t
start = stop
elif thresh_range == "scale":
start = np.prod(coeffs_shape[0])
start_hh = start
for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)):
scale_sz = np.prod(scale_shape)
matched_bands = np.all(scale_shape == coeffs_shape[1:], axis=1)
band_per_scale = np.sum(matched_bands)
start_hh = start + scale_sz * (band_per_scale-1)
stop = start + scale_sz * band_per_scale
t = sigma_bands[i+1] * _thresh_select(
wavelet_coeffs[start_hh:stop] / sigma_bands[i+1],
thresh_estimation
)
ts.append(t)
weights[start:stop] = t
start = stop
return weights
class AutoWeightedSparseThreshold(SparseThreshold):
    """Automatic weighting of sparse coefficients.

    This proximity operator automatically determines the threshold for
    sparse (e.g. wavelet-based) coefficients.

    The weights are computed on the first call, and updated every
    ``update_period`` calls.

    Note that the coarse/approximation scale will not be thresholded.

Parameters
----------
    coeffs_shape: list of tuple
        List of shapes for the subbands.
    linear: LinearOperator
        Required for cost estimation.
    update_period: int
        Update period for the weight estimation. If 0, the weights are only
        estimated on the first call.
    sigma_range: str
        Noise std range of estimation. Available are "global", "scale"
        and "band".
    thresh_range: str
        Threshold range of estimation. Available are "global", "scale"
        and "band".
    threshold_estimation: str
        Threshold estimation method. Available are "sure", "hybrid-sure"
        and "universal".
    threshold_scaler: float or callable, default 1.0
        Scaling applied to the estimated thresholds. If callable, it is
        called as ``threshold_scaler(weights, n_calls)``.
    thresh_type: str
        "hard" or "soft" thresholding.
"""
def __init__(self, coeffs_shape, linear=Identity(), update_period=0,
sigma_range="global",
thresh_range="global",
threshold_estimation="sure",
threshold_scaler=1.0,
**kwargs):
self._n_op_calls = 0
self.cf_shape = coeffs_shape
self._update_period = update_period
        if thresh_range not in ["band", "scale", "global"]:
            raise ValueError("Unsupported threshold range.")
        if sigma_range not in ["band", "scale", "global"]:
            raise ValueError("Unsupported sigma estimation range.")
        if threshold_estimation not in ["sure", "hybrid-sure", "universal"]:
            raise ValueError("Unsupported threshold estimation method.")
self._sigma_range = sigma_range
self._thresh_range = thresh_range
self._thresh_estimation = threshold_estimation
self._thresh_scale = threshold_scaler
weights_init = np.zeros(np.sum(np.prod(coeffs_shape, axis=-1)))
super().__init__(weights=weights_init,
linear=linear,
**kwargs)
    def _auto_thresh(self, input_data):
        """Compute the best weights for the input_data.

        Parameters
        ----------
        input_data: numpy.ndarray
            Array of sparse coefficients.

        Returns
        -------
        numpy.ndarray
            Array of estimated weights, one per coefficient.

        See Also
        --------
        wavelet_threshold_estimate
        """
weights = wavelet_threshold_estimate(
input_data,
self.cf_shape,
thresh_range=self._thresh_range,
sigma_range=self._sigma_range,
thresh_estimation=self._thresh_estimation,
)
if callable(self._thresh_scale):
weights = self._thresh_scale(weights, self._n_op_calls)
else:
weights *= self._thresh_scale
return weights
    def _op_method(self, input_data, extra_factor=1.0):
        """Operator.

        This method returns the input data thresholded by the weights.
        The weights are estimated from the data with the configured
        threshold estimation method, and refreshed every ``update_period``
        calls.

Parameters
----------
input_data : numpy.ndarray
Input data array
extra_factor : float
Additional multiplication factor (default is ``1.0``)
Returns
-------
numpy.ndarray
Thresholded data
"""
        # Re-estimate the weights on the first call only (update_period == 0),
        # or periodically, every `update_period` calls.
        if self._update_period == 0 and self._n_op_calls == 0:
            self.weights = self._auto_thresh(input_data)
        if self._update_period != 0 and self._n_op_calls % self._update_period == 0:
            self.weights = self._auto_thresh(input_data)
self._n_op_calls += 1
return super()._op_method(input_data, extra_factor=extra_factor)