"""Fourier operators for online reconstructions."""
import numpy as np
import scipy as sp
from .base import FourierOperatorBase
[docs]class ColumnFFT(FourierOperatorBase):
    """
    Fourier operator optimized to compute the 2D FFT + selection of various line of the kspace.
    The FFT will be normalized in a symmetric way.
    Currently work only in 2D or stack of 2D.
    Attributes:
    -----------
    shape: tuple of int
        shape of the image (not necessarly a square matrix).
    n_coils: int, default 1
        Number of coils used to acquire the signal in case of multiarray
        receiver coils acquisition. If n_coils > 1, data shape must be
        [n_coils, Nx, Ny, NZ]
    n_jobs: int, default 1
        Number of parallel workers to use for fourier computation
    mask: int
        The column of the kspace which is kept.
    Notes:
    ------
    This Operator performs a 1D FFT on the column axis of the provided data
    and then perform a classical DFT operation (there is only one frequency to compute).
    This method is faster and cheaper than the regular 2D FFT+mask operation.
    """
    def __init__(self, shape, line_index=0, n_coils=1):
        """Initilize the 'FFT' class.
        Parameters:
        ----------
        shape: tuple of int
            shape of the image (not necessarly a square matrix).
        n_coils: int, default 1
            Number of coils used to acquire the signal in case of
            multiarray receiver coils acquisition. If n_coils > 1,
            data shape must be equal to [n_coils, Nx, Ny, NZ]
        line_index: int
            The index of the column onto the line_axis of the kspace
        n_jobs: int, default 1
            Number of parallel workers to use for fourier computation
            All cores are used if -1
        package: str
            The plateform on which to run the computation. can be either 'numpy', 'numba', 'cupy'
        """
        self.shape = shape
        if n_coils <= 0:
            n_coils = 1
        self.n_coils = n_coils
        self._exp_f = np.zeros(shape[1], dtype=complex)
        self._exp_b = np.zeros(shape[1], dtype=complex)
        self._mask = line_index
    @property
    def mask(self):
        """Return the column index of the mask."""
        return self._mask
    @mask.setter
    def mask(self, val: int, shift=True):
        """Set the mask vaule and update the frequencies vectors use for the DFT.
        Parameters
        ----------
        val: int
            Index of the column considered for the masked fft.
        shift: bool, optional
            If true, the frequency is shifted (0 = center)
        Raises
        ------
        IndexError: if the column index is not in range.
        Notes
        -----
        the vector for  forward frequencies is defined as:
        .. math:: u_i  = \frac{1}{\sqrt{N}} * exp(-2\pi ji/N)
        similarly for the adjoint operation:
        .. math:: v_i = u_i^* = \frac{1}{\sqrt{N}} * exp(-2\pi ji/N)
        """
        if shift:
            val_shift = (self.shape[1] // 2 + val) % self.shape[1]
        if val >= self.shape[1]:
            raise IndexError("Index out of range")
        self._mask = val
        cos = np.cos(2 * np.pi * val_shift / self.shape[1])
        sin = np.sin(2 * np.pi * val_shift / self.shape[1])
        exp_f = cos - 1j * sin
        exp_b = cos + 1j * sin
        self._exp_f = (1 / np.sqrt(self.shape[1])) * (exp_f ** np.arange(self.shape[1]))
        self._exp_b = (1 / np.sqrt(self.shape[1])) * (exp_b ** np.arange(self.shape[1]))
[docs]    def op(self, img):
        """Compute the masked 2D Fourier transform of a 2d or 3D image.
        Parameters
        ----------
        img: numpy.ndarray
            input ND array with the same shape as the mask. For multichannel
            images the coils dimension is put first
        Returns
        -------
        x: numpy.ndarray
            masked Fourier transform of the input image. For multichannel
            images the coils dimension is put first
        """
        return sp.fft.ifftshift(
            sp.fft.fft(
                np.dot(sp.fft.fftshift(img, axes=[-1, -2]), self._exp_f),
                axis=-1,
                norm="ortho",
            ),
            axes=[-1],
        ) 
[docs]    def adj_op(self, x):
        """Compute inverse masked Fourier transform of a ND image.
        Parameters
        ----------
        x: numpy.ndarray
            masked Fourier transform data. For multichannel
            images the coils dimension is put first
        Returns
        -------
        img: numpy.ndarray
            inverse ND discrete Fourier transform of the input coefficients.
            For multichannel images the coils dimension is put first
        """
        
        return sp.fft.fftshift(
            np.multiply.outer(
                sp.fft.ifft(sp.fft.ifftshift(x, axes=[-1]), axis=-1, norm="ortho"),
                self._exp_b,
            ),
            axes=[-1, -2],
        )