# Source code for etomo.operators.linear.HOTV

"""
TV and HOTV class
"""
import numpy as np
import scipy
import scipy.sparse
import scipy.special

from .base import LinearBase


class HOTV(LinearBase):
    """ The HOTV computation class for 2D image decomposition

    The operator stacks two finite-difference operators of the given
    ``order`` applied to the row-major (C-order) flattened image: one along
    the last axis (diagonal offsets ``0..order``) and one along the first
    axis (diagonal offsets ``img_size * (0..order)``).

    .. note:: At the moment, assumed that the image is square

    .. note:: Differences are taken on the flattened image without explicit
       boundary handling, so differences along the last axis near the end
       of a row also involve the first pixels of the next row.
    """

    def __init__(self, img_shape, order=1, **kwargs):
        """ Initialize the 'HOTV' class.

        Parameters
        ----------
        img_shape: tuple(int)
            image dimensions; only square images are supported
        order: int, optional
            order of the differential operator used for the HOTV
            computation (default is 1)
        kwargs:
            Keyword arguments for LinearBase initialization

        Raises
        ------
        AssertionError
            if ``img_shape`` is not square.
        ValueError
            if ``order`` is negative.
        """
        super().__init__(**kwargs)
        assert img_shape[0] == img_shape[1], (
            "HOTV only supports square images, got shape " + str(img_shape))
        if order < 0:
            raise ValueError(
                "order must be a non-negative integer, got " + str(order))
        self.order = order
        self.img_size = img_shape[0]

        # Binomial finite-difference stencil:
        #   filter[k] = (-1) ** (order - k) * C(order, k),
        # stored as a float column of shape (order + 1, 1).
        steps = np.arange(order + 1)
        self.filter = ((-1.0) ** (order - steps)
                       * scipy.special.binom(order, steps)).reshape(-1, 1)

        n_pixels = self.img_size ** 2
        shape = (n_pixels, n_pixels)
        # Row j of each matrix computes sum_k filter[k] * x[j + stride * k];
        # stride 1 differentiates along the last axis, stride img_size along
        # the first axis of the flattened image.
        sparse_mat_x = scipy.sparse.diags(
            self.filter, offsets=steps, shape=shape)
        sparse_mat_y = scipy.sparse.diags(
            self.filter, offsets=self.img_size * steps, shape=shape)
        # CSR gives efficient products with both the matrix and its
        # transpose (CSC), which _op and _adj_op use on every call.
        self.op_matrix = scipy.sparse.vstack(
            [sparse_mat_x, sparse_mat_y], format="csr")

    def _op(self, data):
        """ Define the HOTV operator.

        This method returns the input data convolved with the HOTV filter.

        Parameters
        ----------
        data: numpy.ndarray((m', m'))
            input 2D data array.

        Returns
        -------
        coeffs: numpy.ndarray((2 * m' * m'))
            the variation values: the differences along the last axis
            followed by the differences along the first axis.
        """
        # ``@`` is unambiguous matrix-vector multiplication for both scipy
        # sparse matrices and sparse arrays.
        return self.op_matrix @ np.ravel(data)

    def _adj_op(self, coeffs):
        """ Define the HOTV adjoint operator.

        This method returns the adjoint of HOTV computed image.

        Parameters
        ----------
        coeffs: numpy.ndarray((2 * m' * m'))
            the HOTV coefficients.

        Returns
        -------
        data: numpy.ndarray((m', m'))
            the reconstructed data.
        """
        return np.reshape(self.op_matrix.T @ np.ravel(coeffs),
                          (self.img_size, self.img_size))

    def l2norm(self, *args):
        """ Return the L2 norm of the operator's response to a unit impulse.

        For an impulse far enough from the image borders, each of the two
        stacked difference operators returns exactly the filter
        coefficients, so the norm is ``sqrt(2 * sum(filter ** 2))`` and can
        be computed analytically. Extra positional arguments are accepted
        and ignored.

        NOTE(review): this is not the spectral norm of ``op_matrix``, which
        is generally larger for ``order >= 1``; confirm this is the quantity
        callers expect.
        """
        return np.sqrt(2 * np.sum(self.filter ** 2))

    def __str__(self):
        return 'HOTV, order ' + str(self.order)
class HOTV_3D(LinearBase):
    """ The HOTV computation class for 3D image decomposition

    The operator stacks three finite-difference operators of the given
    ``order`` applied to the row-major (C-order) flattened volume of shape
    ``(nb_slices, img_size, img_size)``:
    - along the last axis (diagonal offsets ``0..order``);
    - along the middle axis (offsets ``img_size * (0..order)``);
    - along the slice axis (offsets ``img_size ** 2 * (0..order)``).

    .. note:: At the moment, assumed that the image is square in x-y
       directions

    .. note:: Differences are taken on the flattened volume without explicit
       boundary handling. Differences near the end of a row also involve the
       first pixels of the next row. Differences near the last row of a
       slice also involve the first row of the next slice.
    """

    def __init__(self, img_shape, nb_slices, order=1, **kwargs):
        """ Initialize the 'HOTV_3D' class.

        Parameters
        ----------
        img_shape: tuple(int)
            image dimensions (assuming that the image is square); only
            ``img_shape[0]`` is used
        nb_slices: int
            number of slices in the 3D reconstructed image
        order: int, default is 1
            order of the differential operator used for the HOTV computation
        kwargs:
            Keyword arguments for LinearBase initialization

        Raises
        ------
        ValueError
            if ``order`` is negative.
        """
        super().__init__(**kwargs)
        if order < 0:
            raise ValueError(
                "order must be a non-negative integer, got " + str(order))
        # NOTE(review): squareness of img_shape is deliberately not checked
        # here (the check was disabled upstream); img_shape[0] is used as the
        # size of both in-plane dimensions.
        self.order = order
        self.img_size = img_shape[0]
        self.nb_slices = nb_slices

        # Binomial finite-difference stencil:
        #   filter[k] = (-1) ** (order - k) * C(order, k),
        # stored as a float column of shape (order + 1, 1).
        steps = np.arange(order + 1)
        self.filter = ((-1.0) ** (order - steps)
                       * scipy.special.binom(order, steps)).reshape(-1, 1)

        n_voxels = (self.img_size ** 2) * self.nb_slices
        shape = (n_voxels, n_voxels)
        # One difference matrix per axis. The stride is the distance between
        # neighbours along that axis in the flattened volume: last axis,
        # middle axis, then slice axis.
        sparse_mats = [
            scipy.sparse.diags(self.filter, offsets=stride * steps,
                               shape=shape)
            for stride in (1, self.img_size, self.img_size ** 2)
        ]
        # CSR gives efficient products with both the matrix and its
        # transpose (CSC), which _op and _adj_op use on every call.
        self.op_matrix = scipy.sparse.vstack(sparse_mats, format="csr")

    def _op(self, data):
        """ Define the HOTV operator.

        This method returns the input data convolved with the HOTV filter.

        Parameters
        ----------
        data: numpy.ndarray((p', m', m'))
            input 3D data array.

        Returns
        -------
        coeffs: numpy.ndarray((3 * p' * m' * m'))
            the variation values, stacked in this order:
            - differences along the last axis;
            - differences along the middle axis;
            - differences along the slice axis.
        """
        # ``@`` is unambiguous matrix-vector multiplication for both scipy
        # sparse matrices and sparse arrays.
        return self.op_matrix @ np.ravel(data)

    def _adj_op(self, coeffs):
        """ Define the HOTV adjoint operator.

        This method returns the adjoint of HOTV computed image.

        Parameters
        ----------
        coeffs: numpy.ndarray((3 * p' * m' * m'))
            the HOTV coefficients.

        Returns
        -------
        data: numpy.ndarray((p', m', m'))
            the reconstructed data.
        """
        return np.reshape(self.op_matrix.T @ np.ravel(coeffs),
                          (self.nb_slices, self.img_size, self.img_size))

    def l2norm(self, *args):
        """ Return the L2 norm of the operator's response to a unit impulse.

        For an impulse far enough from the volume borders, each of the three
        stacked difference operators returns exactly the filter
        coefficients, so the norm is ``sqrt(3 * sum(filter ** 2))`` and can
        be computed analytically. Extra positional arguments are accepted
        and ignored.

        NOTE(review): this is not the spectral norm of ``op_matrix``, which
        is generally larger for ``order >= 1``; confirm this is the quantity
        callers expect.
        """
        return np.sqrt(3 * np.sum(self.filter ** 2))

    def __str__(self):
        return 'HOTV order ' + str(self.order)