Source code for mri.operators.linear.wavelet

# -*- coding: utf-8 -*-
##########################################################################
# pySAP - Copyright (C) CEA, 2017 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""Provide linears operators classes adapted to MRI reconstruction algorithms."""


import warnings

import joblib
import numpy as np
import pysap
from joblib import Parallel, delayed
from modopt.signal.wavelet import filter_convolve, get_mr_filters
from pysap.base.utils import flatten, unflatten
from pysap.utils import wavelist

from ..base import OperatorBase


class WaveletN(OperatorBase):
    """2D and 3D wavelet transform class.

    Wrap a pysap transform as a linear operator, optionally applying it to
    several coils in parallel with joblib.

    Parameters
    ----------
    wavelet_name: str
        the wavelet name to be used during the decomposition. Must be one of
        ``pysap.AVAILABLE_TRANSFORMS``.
    nb_scale: int, default 4
        the number of scales in the decomposition.
    verbose: int, default 0
        the verbosity level (forwarded to the transforms and to joblib).
    dim: int, default 2
        dimension argument forwarded to the pysap transform constructor.
    n_coils: int, default 1
        the number of coils for multichannel reconstruction. When greater
        than 1, the first axis of the data passed to ``op`` indexes coils.
    n_jobs: int, default 1
        the number of cores to use for multichannel. Negative values are
        counted back from the number of CPUs (``-1`` means all CPUs).
    backend: str, default "threading"
        the backend to use for parallel multichannel linear operation.
    **kwargs
        extra keyword arguments forwarded to the pysap transform constructor.

    Attributes
    ----------
    nb_scale: int
        number of scales decomposed in wavelet space.
    n_jobs: int
        number of jobs for parallel computation.
    n_coils: int
        number of coils used.
    backend: str
        backend used for parallel computation.
    verbose: int
        verbosity level.
    transform_queue: list
        pool of transform instances shared between parallel jobs.
    coeffs_shape
        shape information returned by ``flatten`` during the last ``op``
        call (a tuple of per-coil shapes when ``n_coils > 1``); ``None``
        until ``op`` has been called.
    """

    def __init__(self, wavelet_name, nb_scale=4, verbose=0, dim=2,
                 n_coils=1, n_jobs=1, backend="threading", **kwargs):
        self.nb_scale = nb_scale
        self.flatten = flatten
        self.unflatten = unflatten
        self.n_jobs = n_jobs
        self.n_coils = n_coils
        if self.n_coils == 1 and self.n_jobs != 1:
            # A single coil gives nothing to parallelize over.
            warnings.warn("Making n_jobs = 1 for WaveletN as n_coils = 1")
            self.n_jobs = 1
        self.backend = backend
        self.verbose = verbose
        if wavelet_name not in pysap.AVAILABLE_TRANSFORMS:
            raise ValueError(f"Unknown transformation '{wavelet_name}'.")
        transform_klass = pysap.load_transform(wavelet_name)
        self.transform_queue = []

        # Resolve negative n_jobs the same way joblib does.
        n_proc = self.n_jobs
        if n_proc < 0:
            n_proc = joblib.cpu_count() + self.n_jobs + 1

        # Only warn when more than one worker was actually requested.
        if n_proc > 1:
            if wavelet_name in wavelist()['isap-2d'] or \
                    wavelet_name in wavelist()['isap-3d']:
                warnings.warn("n_jobs is currently unsupported "
                              "for ISAP wavelets, setting n_jobs=1")
                self.n_jobs = 1
                n_proc = 1

        # One transform per concurrent worker, but never fewer than one:
        # joblib clamps its worker count to at least 1, so an empty queue
        # would make `_op` / `_adj_op` fail on `pop()`.
        n_transforms = max(1, min(n_proc, self.n_coils))
        for _ in range(n_transforms):
            self.transform_queue.append(transform_klass(
                nb_scale=self.nb_scale,
                verbose=verbose,
                dim=dim,
                **kwargs,
            ))
        self.coeffs_shape = None

    def _op(self, data):
        """Apply the forward transform to a single-channel image.

        Parameters
        ----------
        data: numpy.ndarray or pysap.Image
            single-channel input data.

        Returns
        -------
        coeffs: numpy.ndarray
            the flattened wavelet coefficients.
        coeffs_shape
            the shape information returned by ``flatten``, needed to
            unflatten the coefficients in ``_adj_op``.
        """
        if isinstance(data, np.ndarray):
            data = pysap.Image(data=data)
        # Borrow a transform from the shared pool so parallel jobs do not
        # operate on the same transform instance.
        transform = self.transform_queue.pop()
        try:
            transform.data = data
            transform.analysis()
            coeffs, coeffs_shape = flatten(transform.analysis_data)
        finally:
            # Always give the transform back, even if the analysis fails.
            self.transform_queue.append(transform)
        return coeffs, coeffs_shape

    def op(self, data):
        """Define the wavelet operator.

        This method returns the input data convolved with the wavelet filter.
        As a side effect it stores the coefficient shape information in
        ``self.coeffs_shape`` for later use by ``adj_op``.

        Parameters
        ----------
        data: numpy.ndarray or Image
            input data array; when ``n_coils > 1`` its first axis indexes
            the coils.

        Returns
        -------
        coeffs: numpy.ndarray
            the wavelet coefficients (stacked over coils when
            ``n_coils > 1``).
        """
        if self.n_coils > 1:
            coeffs, self.coeffs_shape = zip(
                *Parallel(
                    n_jobs=self.n_jobs,
                    backend=self.backend,
                    verbose=self.verbose,
                )(
                    delayed(self._op)(data[i])
                    for i in np.arange(self.n_coils)
                )
            )
            coeffs = np.asarray(coeffs)
        else:
            coeffs, self.coeffs_shape = self._op(data)
        return coeffs

    def _adj_op(self, coeffs, coeffs_shape, dtype="array"):
        """Apply the adjoint (synthesis) transform to a single channel.

        Parameters
        ----------
        coeffs: numpy.ndarray
            the flattened wavelet coefficients.
        coeffs_shape
            the shape information returned by ``flatten`` in ``_op``.
        dtype: str, default 'array'
            if 'array' return the data as a ndarray, otherwise return a
            pysap.Image.

        Returns
        -------
        data: numpy.ndarray or pysap.Image
            the reconstructed data.
        """
        transform = self.transform_queue.pop()
        try:
            transform.analysis_data = unflatten(coeffs, coeffs_shape)
            image = transform.synthesis()
        finally:
            # Always give the transform back, even if the synthesis fails.
            self.transform_queue.append(transform)
        if dtype == "array":
            return image.data
        return image

    def adj_op(self, coeffs):
        """Define the wavelet adjoint operator.

        This method returns the reconstructed image. It relies on
        ``self.coeffs_shape`` set by a previous call to ``op``.

        Parameters
        ----------
        coeffs: numpy.ndarray
            the wavelet coefficients.

        Returns
        -------
        data: numpy.ndarray
            the reconstructed data.

        Raises
        ------
        RuntimeError
            if ``op`` has not been called yet.
        """
        if self.coeffs_shape is None:
            raise RuntimeError(
                "`op` must be run before `adj_op` to get the data shape",
            )
        if self.n_coils > 1:
            images = Parallel(
                n_jobs=self.n_jobs,
                backend=self.backend,
                verbose=self.verbose,
            )(
                delayed(self._adj_op)(coeffs[i], self.coeffs_shape[i])
                for i in np.arange(self.n_coils)
            )
            images = np.asarray(images)
        else:
            images = self._adj_op(coeffs, self.coeffs_shape)
        return images

    def l2norm(self, shape):
        """Compute the L2 norm of the operator applied to a centered impulse.

        Parameters
        ----------
        shape: tuple
            The data shape. Odd dimensions are rounded up to the next even
            size before building the impulse.

        Returns
        -------
        norm: float
            The L2 norm.
        """
        # Copy, so an ndarray passed by the caller is not modified in place.
        shape = np.array(shape)
        shape += shape % 2
        fake_data = np.zeros(shape)
        fake_data[tuple(shape // 2)] = 1
        data = self.op(fake_data)
        return np.linalg.norm(data)
class WaveletUD2(OperatorBase):
    """Wavelet undecimated operator using pysap wrapper.

    The wavelet filters are computed with ``get_mr_filters`` from the shape
    of the input data and then applied by convolution (``filter_convolve``)
    to the real and imaginary parts separately.

    Parameters
    ----------
    wavelet_id: int, default 24 = undecimated (bi-) orthogonal transform
        ID of wavelet being used (passed as ``-t<id>`` to get_mr_filters).
    nb_scale: int, default 4
        the number of scales in the decomposition.
    n_jobs: int, default 1
        Number of CPUs to run on. Only applicable if n_coils > 1.
    backend: 'threading' | 'multiprocessing', default 'threading'
        Denotes the backend to use for parallel execution across multiple
        channels.
    n_coils: int, default 1
        Number of channels. When greater than 1, the first axis of the data
        passed to ``op`` indexes the channels.
    verbose: int, default 0
        The verbosity level for Parallel operation from joblib.

    Attributes
    ----------
    _has_run: bool
        Checks if the get_mr_filters was called already.
    transform: numpy.ndarray or None
        the filters returned by get_mr_filters.
    coeffs_shape
        shape information returned by ``flatten`` during the last ``op``
        call (a tuple of per-channel shapes when ``n_coils > 1``).
    """

    def __init__(self, wavelet_id=24, nb_scale=4, n_jobs=1,
                 backend='threading', n_coils=1, verbose=0):
        self.wavelet_id = wavelet_id
        self.n_coils = n_coils
        self.nb_scale = nb_scale
        self.n_jobs = n_jobs
        self.backend = backend
        self.verbose = verbose
        self._opt = [
            f'-t{self.wavelet_id}',
            f'-n{self.nb_scale}',
        ]
        self._has_run = False
        # Single-channel shape the current filters were computed for.
        self._filter_shape = None
        self.coeffs_shape = None
        self.transform = None

    def _get_filters(self, shape):
        """Get the Wavelet coefficients of Delta[0][0].

        The filters are computed for a given single-channel shape; the
        wavelet coefficients are then obtained by convolving these filters
        with the input data.

        Parameters
        ----------
        shape: tuple or array
            Shape of data on which the filter will be applied.
        """
        self.transform = get_mr_filters(
            tuple(shape),
            opt=self._opt,
            coarse=True,
        )
        self._filter_shape = tuple(shape)
        self._has_run = True

    def _op(self, data):
        """Define the wavelet operator for single channel.

        Returns wavelet coefficients for a single channel.

        Parameters
        ----------
        data: numpy.ndarray
            single-channel input data array.

        Returns
        -------
        coeffs: numpy.ndarray
            the flattened wavelet coefficients.
        coeffs_shape
            the shape information returned by ``flatten``.
        """
        # The filters are real-valued, so real and imaginary parts are
        # convolved independently and recombined.
        coefs_real = filter_convolve(data.real, self.transform)
        coefs_imag = filter_convolve(data.imag, self.transform)
        coeffs, coeffs_shape = flatten(coefs_real + 1j * coefs_imag)
        return coeffs, coeffs_shape

    def op(self, data):
        """Define the wavelet operator.

        This method returns the input data convolved with the wavelet filter.
        The filters are (re)computed whenever the single-channel shape of
        ``data`` differs from the one they were built for.

        Parameters
        ----------
        data: numpy.ndarray
            input data array; when ``n_coils > 1`` its first axis indexes
            the channels.

        Returns
        -------
        coeffs: numpy.ndarray
            the wavelet coefficients.
        """
        if self.n_coils > 1:
            image_shape = tuple(data.shape[1:])
        else:
            image_shape = tuple(data.shape)
        # Filters depend on the data shape: reusing filters built for another
        # shape (e.g. the padded shape used by `l2norm`) gives wrong results.
        if not self._has_run or image_shape != self._filter_shape:
            self._get_filters(image_shape)
        if self.n_coils > 1:
            coeffs, self.coeffs_shape = zip(*Parallel(
                n_jobs=self.n_jobs,
                backend=self.backend,
                verbose=self.verbose,
            )(
                delayed(self._op)(data[i])
                for i in np.arange(self.n_coils)
            ))
            coeffs = np.asarray(coeffs)
        else:
            coeffs, self.coeffs_shape = self._op(data)
        return coeffs

    def _adj_op(self, coeffs, coeffs_shape):
        """Define the wavelet adjoint operator for a single channel.

        Parameters
        ----------
        coeffs: numpy.ndarray
            the wavelet coefficients.
        coeffs_shape
            The shape information used to unflatten the coefficients before
            the adjoint operation.

        Returns
        -------
        data: numpy.ndarray
            the reconstructed data.
        """
        # The adjoint of a convolution is a convolution with the rotated
        # filters, applied separately to real and imaginary parts.
        data_real = filter_convolve(
            np.squeeze(unflatten(coeffs.real, coeffs_shape)),
            self.transform, filter_rot=True)
        data_imag = filter_convolve(
            np.squeeze(unflatten(coeffs.imag, coeffs_shape)),
            self.transform, filter_rot=True)
        return data_real + 1j * data_imag

    def adj_op(self, coeffs):
        """Define the wavelet adjoint operator.

        This method returns the reconstructed image.

        Parameters
        ----------
        coeffs: numpy.ndarray
            the wavelet coefficients.

        Returns
        -------
        data: numpy.ndarray
            the reconstructed data.

        Raises
        ------
        RuntimeError
            if ``op`` has not been called yet.
        """
        if not self._has_run:
            raise RuntimeError(
                "`op` must be run before `adj_op` to get the data shape",
            )
        if self.n_coils > 1:
            images = Parallel(
                n_jobs=self.n_jobs,
                backend=self.backend,
                verbose=self.verbose,
            )(
                delayed(self._adj_op)(coeffs[i], self.coeffs_shape[i])
                for i in np.arange(self.n_coils)
            )
            images = np.asarray(images)
        else:
            images = self._adj_op(coeffs, self.coeffs_shape)
        return images

    def l2norm(self, shape):
        """Compute the L2 norm of the operator applied to a centered impulse.

        Parameters
        ----------
        shape: tuple
            the data shape. Odd dimensions are rounded up to the next even
            size before building the impulse.

        Returns
        -------
        norm: float
            the L2 norm.
        """
        # Copy, so an ndarray passed by the caller is not modified in place.
        shape = np.array(shape)
        shape += shape % 2
        fake_data = np.zeros(shape)
        fake_data[tuple(shape // 2)] = 1
        data = self.op(fake_data)
        return np.linalg.norm(data)