# Source code for mri.operators.linear.utils

# -*- coding: utf-8 -*-
# pySAP - Copyright (C) CEA, 2017 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.

"""Functions used for dictionary-learning Compressed Sensing reconstruction."""

# System import
from __future__ import division
import time
import itertools

# Third party import
import numpy as np
import progressbar
from sklearn.utils import check_random_state, gen_batches
from sklearn.decomposition import MiniBatchDictionaryLearning
from sklearn.feature_extraction.image import extract_patches_2d

def timer(start, end):
    """Format the time elapsed between two timestamps as hh:mm:ss.

    Parameters
    ----------
    start: float
        the starting time, in seconds (e.g. a ``time.time()`` value).
    end: float
        the ending time, in the same unit as ``start``.

    Returns
    -------
    str: the duration formatted as hh:mm:ss, with hours and minutes
        zero-padded to two digits and seconds shown with two decimals.
    """
    elapsed = end - start
    # Peel off whole hours first, then whole minutes; the rest is seconds.
    hours, leftover = divmod(elapsed, 3600)
    minutes, seconds = divmod(leftover, 60)
    return f"{int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}"
def min_max_normalize(img):
    """Rescale the given array linearly to the [0, 1] range.

    NaNs are first replaced by 0 and infinities by large finite values
    (``numpy.nan_to_num``); the array is then shifted by its minimum and
    divided by its range (max - min).

    Parameters
    ----------
    img: numpy.ndarray
        the array to normalize.

    Returns
    -------
    ndarray: the rescaled array, with values in [0, 1]. A constant array
        (zero range) is mapped to an array of zeros. Floating/complex inputs
        keep their dtype; integer inputs give float64.
    """
    img = np.nan_to_num(img)
    min_img = img.min()
    max_img = img.max()
    value_range = max_img - min_img
    if value_range == 0:
        # A constant image (e.g. the imaginary part of real-valued data) has
        # no dynamic range. Dividing would compute 0/0 -> NaN (with a
        # RuntimeWarning) only to be zeroed afterwards, so return the zeros
        # directly, with the dtype the division would have produced.
        if np.issubdtype(img.dtype, np.inexact):
            out_dtype = img.dtype
        else:
            out_dtype = np.float64
        return np.zeros(img.shape, dtype=out_dtype)
    img = (img - min_img) / value_range
    return np.nan_to_num(img)
def extract_patches_from_2d_images(img, patch_shape):
    """Extract the patches of a 2d image, each flattened into one row.

    Parameters
    ----------
    img: numpy.ndarray
        the input 2d image.
    patch_shape: tuple of int
        the shape of the patches.

    Returns
    -------
    patches: numpy.ndarray
        a 2d matrix of dimension nb_patches x (patch_shape[0] * patch_shape[1]),
        one flattened patch per row.
    """
    stacked = extract_patches_2d(img, patch_shape)
    nb_patches = stacked.shape[0]
    # Collapse every axis after the patch index into a single feature axis.
    return stacked.reshape(nb_patches, -1)
def generate_flat_patches(images, patch_size, option='real'):
    """Generate flat patches from the images of several subjects.

    Each image is reduced to its real part, imaginary part or modulus (or
    kept complex, see ``option``), min-max normalized to [0, 1], then cut
    into flattened square patches.

    Parameters
    ----------
    images: iterable of list of np.ndarray of float or complex
        one sublist per subject, containing all the images of that subject.
    patch_size: int
        width of the square patches.
    option: {'real', 'imag', 'abs', 'complex'}, default 'real'
        which component of the images to use. With 'complex', the real and
        imaginary parts are normalized separately and recombined as
        ``real + 1j * imag``.

    Yields
    ------
    flat_patches_sub: list of np.ndarray
        for each subject, the list of flat-patch matrices, one per image.

    Raises
    ------
    ValueError
        if ``option`` is not supported. As this is a generator, the error is
        raised when it is first advanced, even if ``images`` is empty.
    """
    # Validate once, up front, instead of per image: an invalid option is
    # now reported even when there are no images to process.
    if option not in ("real", "imag", "abs", "complex"):
        raise ValueError(f"Unsupported option: '{option}'")
    patch_shape = (patch_size, patch_size)
    # Component of the image kept by each real-valued option.
    components = {
        "abs": lambda img: np.abs(img).astype("float"),
        "real": np.real,
        "imag": np.imag,
    }

    def _flat_patches(component):
        # Normalize one real-valued component, then cut it into flat patches.
        return extract_patches_from_2d_images(
            min_max_normalize(component), patch_shape)

    def _extract(img):
        # Flat patches of one image according to ``option``.
        if option == "complex":
            return (_flat_patches(np.real(img))
                    + 1j * _flat_patches(np.imag(img)))
        return _flat_patches(components[option](img))

    # Iterate directly (no ``images[:]`` copy) so any iterable is accepted.
    for imgs in images:
        yield [_extract(img) for img in imgs]
def learn_dictionary(flat_patches_subjects, nb_atoms=100, alpha=1, n_iter=1,
                     fit_algorithm='lars', transform_algorithm='lasso_lars',
                     batch_size=100, n_jobs=1, verbose=1):
    """Learn the dictionary from a training set.

    The patches of each subject are pooled into one array, shuffled with a
    fixed seed, and fed to ``MiniBatchDictionaryLearning.partial_fit`` in
    mini-batches of ``batch_size`` patches.

    Parameters
    ----------
    flat_patches_subjects: iterable of list of 2d np.ndarray
        one item per subject, each a list of flat-patch matrices (one patch
        per row), e.g. as yielded by ``generate_flat_patches``.
    nb_atoms: int (default 100)
        number of components of the dictionary.
    alpha: float (default 1)
        regularization term.
    n_iter: int (default 1)
        number of iterations, forwarded to MiniBatchDictionaryLearning.
    fit_algorithm: str (default 'lars')
        for more details see MiniBatchDictionaryLearning from the sklearn
        library.
    transform_algorithm: str (default 'lasso_lars')
        for more details see MiniBatchDictionaryLearning from the sklearn
        library.
    batch_size: int (default 100)
        number of patches taken per iteration to fit the model; the last
        batch of a subject may be smaller.
    n_jobs: int (default 1)
        number of CPUs used for the learning.
    verbose: int (default 1)
        level of verbosity:
        - 0 is silent.
        - 1 reports the number of patches per subject and the total
          learning time.
        - 2 additionally reports the start of learning and the time spent
          on each batch.

    Returns
    -------
    dico: MiniBatchDictionaryLearning object
        the fitted dictionary learner.
    """
    # NOTE(review): recent scikit-learn releases deprecated ``n_iter`` in
    # favour of ``max_iter``; confirm against the pinned sklearn version.
    dico = MiniBatchDictionaryLearning(
        n_components=nb_atoms, alpha=alpha, n_iter=n_iter,
        fit_algorithm=fit_algorithm,
        transform_algorithm=transform_algorithm,
        n_jobs=n_jobs, verbose=0)
    # Fixed seed so that the patch shuffling is reproducible across runs.
    rng = check_random_state(0)
    if verbose >= 2:
        print("Dictionary Learning starting")
    t_start = time.time()
    for patches_subject in flat_patches_subjects:
        # Stack all patch matrices of the subject into one (n, features)
        # array (a fresh copy, so the caller's arrays are not shuffled).
        chunks = list(patches_subject)
        if chunks:
            patches = np.concatenate(chunks, axis=0)
        else:
            patches = np.empty((0, 0))
        if verbose >= 1:
            print(f"[info] number of patches of the subject: {len(patches)}")
        if len(patches) == 0:
            continue
        rng.shuffle(patches)  # shuffles the rows (patches) in place
        # gen_batches also yields the final partial batch, so size the
        # progress bar on the actual number of batches.
        batches = list(gen_batches(len(patches), batch_size))
        with progressbar.ProgressBar(max_value=len(batches),
                                     redirect_stdout=True) as bar:
            for cnt, batch in enumerate(batches, start=1):
                t_start2 = time.time()
                # Fit on the whole mini-batch (previously only its first
                # patch was used, which ignored ``batch_size``).
                dico.partial_fit(patches[batch])
                duration = time.time() - t_start2
                if verbose >= 2:
                    print(f"[info] batch time: {duration}")
                bar.update(cnt)
    t_end = time.time()
    if verbose >= 1:
        print(f"[info] dictionary learned in {timer(t_start, t_end)}")
    return dico