"""
TV and higher-order TV (HOTV) linear operator classes; HOTV with ``order=1`` is plain TV.
"""
import numpy as np
import scipy
import scipy.sparse
import scipy.special

from .base import LinearBase
class HOTV(LinearBase):
    """The HOTV computation class for 2D image decomposition.

    The operator stacks two sparse finite-difference matrices of the given
    ``order``, one along each image axis. They are applied to the row-major
    (C-order) flattened image, so ``_op`` maps an ``(m', m')`` image to a
    vector of ``2 * m' * m'`` coefficients.

    .. note:: At the moment, assumed that the image is square

    .. note:: The difference stencils act on the flattened image with no
       boundary handling. Along x they couple the last pixel of one row with
       the first pixel of the next row. Near the end of the vector the
       stencils are truncated.
    """

    def __init__(self, img_shape, order=1, **kwargs):
        """
        Initialize the 'HOTV' class.

        Parameters
        ----------
        img_shape: tuple(int)
            image dimensions; the image must be square
            (``img_shape[0] == img_shape[1]``)
        order: int, optional
            order of the differential operator used for the HOTV computation
            (non-negative, default 1)
        kwargs:
            Keyword arguments for LinearBase initialization

        Raises
        ------
        ValueError
            if the image is not square or ``order`` is negative.
        """
        super().__init__(**kwargs)
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if img_shape[0] != img_shape[1]:
            raise ValueError(
                "HOTV requires a square image, got shape {}".format(img_shape))
        if order < 0:
            raise ValueError(
                "order must be a non-negative integer, got {}".format(order))
        self.order = order
        self.img_size = img_shape[0]
        ks = np.arange(order + 1)
        # Signed binomial coefficients give the order-th forward-difference
        # stencil, e.g. [-1, 1] for order 1 and [1, -2, 1] for order 2.
        # Stored as an (order + 1, 1) column so each stencil tap becomes one
        # (broadcast) diagonal of the sparse matrix.
        self.filter = ((-1.0) ** (order - ks)
                       * scipy.special.binom(order, ks)).reshape(-1, 1)
        shape = (self.img_size ** 2,) * 2
        # Neighbour distance in the flattened (row, column) ordering:
        # 1 along x (columns), img_size along y (rows).
        strides = (1, self.img_size)
        self.op_matrix = scipy.sparse.vstack([
            scipy.sparse.diags(self.filter, offsets=stride * ks, shape=shape)
            for stride in strides
        ])

    def _op(self, data):
        """
        Define the HOTV operator.

        This method returns the input data convolved with the HOTV filter.

        Parameters
        ----------
        data: numpy.ndarray((m', m'))
            input 2D data array.

        Returns
        -------
        coeffs: numpy.ndarray((2 * m' * m'))
            the variation values: first the x differences, then the y
            differences.
        """
        # ``@`` is matrix-vector product for both sparse matrices and sparse
        # arrays (``*`` is elementwise for the latter).
        return self.op_matrix @ np.ravel(data)

    def _adj_op(self, coeffs):
        """
        Define the HOTV adjoint operator.

        This method returns the adjoint of HOTV computed image.

        Parameters
        ----------
        coeffs: numpy.ndarray((2 * m' * m'))
            the HOTV coefficients.

        Returns
        -------
        data: numpy.ndarray((m', m'))
            the reconstructed data.
        """
        return np.reshape(self.op_matrix.T @ np.ravel(coeffs),
                          (self.img_size, self.img_size))

    def l2norm(self, *args):
        """
        L2 norm can be computed analytically.

        Returns ``sqrt(2 * sum(filter ** 2))``. This is the Euclidean norm of
        the HOTV coefficients of a unit impulse placed away from the image
        borders, since each of the two difference stencils contributes
        ``sum(filter ** 2)``. Any positional arguments are ignored.
        """
        return np.sqrt(2 * np.sum(self.filter ** 2))

    def __str__(self):
        return ('HOTV, order ' + str(self.order))
class HOTV_3D(LinearBase):
    """
    The HOTV computation class for 3D image decomposition.

    The operator stacks three sparse finite-difference matrices of the given
    ``order``, along x, y and the slice axis z. They are applied to the
    row-major (C-order) flattened volume, so ``_op`` maps a ``(p', m', m')``
    volume to a vector of ``3 * p' * m' * m'`` coefficients.

    .. note:: At the moment, assumed that the image is square in x-y directions

    .. note:: The difference stencils act on the flattened volume with no
       boundary handling. They wrap across row and slice boundaries and are
       truncated near the end of the vector.
    """

    def __init__(self, img_shape, nb_slices, order=1, **kwargs):
        """
        Initialize the 'HOTV_3D' class.

        Parameters
        ----------
        img_shape: tuple(int)
            in-plane image dimensions. Only ``img_shape[0]`` is used, for
            both x and y; squareness is assumed but not checked.
        nb_slices: int
            number of slices in the 3D reconstructed image
        order: int, default is 1
            order of the differential operator used for the HOTV computation
            (non-negative)
        kwargs:
            Keyword arguments for LinearBase initialization

        Raises
        ------
        ValueError
            if ``order`` is negative.
        """
        super().__init__(**kwargs)
        if order < 0:
            raise ValueError(
                "order must be a non-negative integer, got {}".format(order))
        self.order = order
        # Squareness is intentionally not enforced: img_shape[0] is used as
        # both the x and the y extent.
        self.img_size = img_shape[0]
        self.nb_slices = nb_slices
        ks = np.arange(order + 1)
        # Signed binomial coefficients give the order-th forward-difference
        # stencil, e.g. [-1, 1] for order 1 and [1, -2, 1] for order 2.
        # Stored as an (order + 1, 1) column so each stencil tap becomes one
        # (broadcast) diagonal of the sparse matrix.
        self.filter = ((-1.0) ** (order - ks)
                       * scipy.special.binom(order, ks)).reshape(-1, 1)
        plane = self.img_size ** 2
        shape = (plane * self.nb_slices,) * 2
        # Neighbour distance in the flattened (slice, row, column) ordering:
        # 1 along x (columns), img_size along y (rows), img_size**2 along z.
        strides = (1, self.img_size, plane)
        self.op_matrix = scipy.sparse.vstack([
            scipy.sparse.diags(self.filter, offsets=stride * ks, shape=shape)
            for stride in strides
        ])

    def _op(self, data):
        """
        Define the HOTV operator.

        This method returns the input data convolved with the HOTV filter.

        Parameters
        ----------
        data: numpy.ndarray((p', m', m'))
            input 3D data array.

        Returns
        -------
        coeffs: numpy.ndarray((3 * p' * m' * m'))
            the variation values: the x, then y, then z differences.
        """
        # ``@`` is matrix-vector product for both sparse matrices and sparse
        # arrays (``*`` is elementwise for the latter).
        return self.op_matrix @ np.ravel(data)

    def _adj_op(self, coeffs):
        """
        Define the HOTV adjoint operator.

        This method returns the adjoint of HOTV computed image.

        Parameters
        ----------
        coeffs: numpy.ndarray((3 * p' * m' * m'))
            the HOTV coefficients.

        Returns
        -------
        data: numpy.ndarray((p', m', m'))
            the reconstructed data.
        """
        return np.reshape(self.op_matrix.T @ np.ravel(coeffs),
                          (self.nb_slices, self.img_size, self.img_size))

    def l2norm(self, *args):
        """
        L2 norm can be computed analytically.

        Returns ``sqrt(3 * sum(filter ** 2))``. This is the Euclidean norm of
        the HOTV coefficients of a unit impulse placed away from the volume
        borders, since each of the three difference stencils contributes
        ``sum(filter ** 2)``. Any positional arguments are ignored.
        """
        return np.sqrt(3 * np.sum(self.filter ** 2))

    def __str__(self):
        return ('HOTV order ' + str(self.order))