Auto-thresholded Cartesian reconstruction

Author: Chaithya G R / Pierre-Antoine Comby

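In this tutorial we reconstruct an MRI image from undersampled Cartesian k-space measurements, first with a manually tuned regularisation parameter and then with a threshold estimated automatically using SURE.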

Import neuroimaging data

We use the toy datasets available in pysap: a 2D brain slice and a Cartesian undersampling mask.

Package imports

# Third party imports
import matplotlib.pyplot as plt
import numpy as np
from modopt.math.metrics import snr, ssim
from modopt.opt.linear import Identity
from modopt.opt.proximity import SparseThreshold
from mri.operators import FFT, WaveletN
from mri.operators.proximity.weighted import AutoWeightedSparseThreshold
from mri.operators.utils import convert_mask_to_locations
from mri.reconstructors import SingleChannelReconstructor
from pysap.data import get_sample_data

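# Load the 2D brain slice and normalise it to a maximum of 1
image = get_sample_data('2d-mri')
print(image.data.min(), image.data.max())
image = image.data
image /= np.max(image)
# Load the Cartesian undersampling mask
mask = get_sample_data("cartesian-mri-mask")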


# Get the locations of the k-space samples
# (not needed below: the FFT operator is built directly from the mask)
kspace_loc = convert_mask_to_locations(mask.data)
# Generate the undersampled k-space data
fourier_op = FFT(mask=mask.data, shape=image.shape)
kspace_data = fourier_op.op(image)

# Zero-order solution: zero-filled inverse FFT of the measurements
image_rec0 = np.abs(fourier_op.adj_op(kspace_data))

# Calculate SSIM
base_ssim = ssim(image_rec0, image)
print(base_ssim)

Out:

9.5823815e-08 2.518473e-05
0.7228762359967604
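
The zero-order solution is simply the inverse FFT of the k-space data with the unacquired samples left at zero. The sketch below reproduces that idea in plain NumPy so the steps are visible; it is not pysap-mri's FFT operator (we make no claim about its normalisation or shift conventions). The helper `cartesian_mask`, the rectangular phantom and `snr_db` are hypothetical, and `snr_db` uses one common SNR convention that may differ from modopt's `snr`. Names are prefixed with `toy_`/`phantom` so they do not overwrite the tutorial's `image` and `mask`.

```python
import numpy as np


def cartesian_mask(shape, accel=4, center_lines=16):
    """Hypothetical regular Cartesian mask: every `accel`-th phase-encode line
    (row) plus a fully sampled band of `center_lines` rows at the k-space centre."""
    toy_mask = np.zeros(shape, dtype=bool)
    toy_mask[::accel, :] = True
    c = shape[0] // 2
    toy_mask[c - center_lines // 2:c + center_lines // 2, :] = True
    return toy_mask


def zero_filled_recon(img, sampling_mask):
    """Undersample the centred, orthonormal 2D FFT of `img` and invert it."""
    kspace = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(img), norm="ortho"))
    kspace_under = kspace * sampling_mask   # unacquired samples are set to zero
    rec = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace_under), norm="ortho"))
    return np.abs(rec)                      # magnitude image, like image_rec0


def snr_db(test, ref):
    """SNR in dB: reference energy over error energy (one common convention)."""
    return 10.0 * np.log10(np.sum(np.abs(ref) ** 2) / np.sum(np.abs(test - ref) ** 2))


# Hypothetical phantom: a bright rectangle in a 128x128 image
phantom = np.zeros((128, 128))
phantom[32:96, 40:88] = 1.0
toy_mask = cartesian_mask(phantom.shape, accel=4, center_lines=16)
toy_rec = zero_filled_recon(phantom, toy_mask)
print(f"sampled fraction: {toy_mask.mean():.3f}, zero-filled SNR: {snr_db(toy_rec, phantom):.1f} dB")
```

With full sampling the round trip returns the phantom exactly; with the regular mask, the missing lines show up as aliasing, which is what the iterative reconstruction below tries to remove.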

POGM optimization

We now refine the zero-order solution using an accelerated proximal gradient algorithm (FISTA or POGM; POGM is used here). The cost function is the sum of a data-consistency (gradient) cost and a sparsity-promoting (proximity) cost on the wavelet coefficients.
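
For reference, the problem solved in the synthesis formulation (`gradient_formulation='synthesis'`) can be written over the wavelet coefficients $\alpha$. The notation is ours, not pysap-mri's: $F_\Omega$ stands for `fourier_op`, $\Psi$ for `linear_op` (with adjoint $\Psi^{*}$), $y$ for `kspace_data`, and $\lambda$ for the weight given to `SparseThreshold` (2e-3, the "mu" printed in the log). This is the standard form and a generic proximal gradient step; the library's internal details may differ.

$$
\hat{\alpha} = \arg\min_{\alpha} \; \underbrace{\tfrac{1}{2}\,\lVert F_\Omega \Psi^{*} \alpha - y \rVert_2^2}_{\text{gradient cost}} \;+\; \underbrace{\lambda \lVert \alpha \rVert_1}_{\text{proximity cost}}, \qquad \hat{x} = \Psi^{*} \hat{\alpha}
$$

$$
\alpha_{k+1} = \operatorname{ST}_{\lambda / L}\!\left( \alpha_k - \tfrac{1}{L}\, \Psi F_\Omega^{*} \left( F_\Omega \Psi^{*} \alpha_k - y \right) \right), \qquad \operatorname{ST}_{\tau}(u) = \frac{u}{|u|}\,\max(|u| - \tau, 0)
$$

Here $L$ is the Lipschitz constant of the gradient (1.1 in the log below) and $\operatorname{ST}_\tau(0) = 0$. FISTA and POGM add momentum terms on top of this basic step.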

# Setup the operators: 3-scale sym8 wavelet transform
linear_op = WaveletN(wavelet_name="sym8", nb_scale=3)

# Soft thresholding with a manually tuned regularisation parameter
regularizer_op = SparseThreshold(Identity(), 2e-3, thresh_type="soft")
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    verbose=1,
)
# Start Reconstruction
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,
    optimization_alg='pogm',
    num_iterations=100,
    cost_op_kwargs={"cost_interval":None},
    metric_call_period=1,
    metrics = {
        "snr":{
            "metric": snr,
            "mapping": {"x_new":"test"},
            "cst_kwargs": {"ref": image},
            "early_stopping":False,
        },
        "ssim":{
            "metric": ssim,
            "mapping": {"x_new":"test"},
            "cst_kwargs": {"ref": image},
            "early_stopping": False,
        }
    }
)

image_rec = np.abs(x_final)
# Calculate SSIM and SNR against the ground truth
recon_ssim = ssim(image_rec, image)
recon_snr = snr(image_rec, image)

print('The Reconstruction SSIM is : ' + str(recon_ssim))
print('The Reconstruction SNR is : ' + str(recon_snr))

Out:

WARNING: Making input data immutable.
Lipschitz constant is 1.1
The lipschitz constraint is satisfied
WARNING: Making input data immutable.
 - mu:  0.002
 - lipschitz constant:  1.1
 - data:  (512, 512)
 - wavelet:  <mri.operators.linear.wavelet.WaveletN object at 0x7f3bd11c5870> - 4
 - max iterations:  100
 - image variable shape:  (1, 512, 512)
----------------------------------------
Starting optimization...

100%|##########| 100/100 [00:11<00:00,  8.59it/s]
 - final iteration number:  100
 - final log10 cost value:  6.0
 - converged:  False
Done.
Execution time:  11.643497677054256  seconds
----------------------------------------
The Reconstruction SSIM is : 0.8597496784357197
The Reconstruction SNR is : 15.947746555163501
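
To see the iteration without pysap-mri, here is a minimal NumPy sketch on a hypothetical toy problem. Two simplifications are assumptions of the sketch, not the tutorial's setup: the image itself is taken to be sparse (identity in place of the wavelet transform), and it uses FISTA rather than POGM. With an orthonormal FFT and a binary mask, the gradient's Lipschitz constant is at most 1, so a unit step is used.

```python
import numpy as np


def soft_threshold(u, tau):
    """Complex soft thresholding: shrink magnitudes by tau, keep the phase."""
    mag = np.abs(u)
    return u * (np.maximum(mag - tau, 0.0) / np.maximum(mag, 1e-12))


def fista(y, mask, lam, n_iter=100):
    """FISTA for 0.5 * ||M F x - y||^2 + lam * ||x||_1 (identity sparsifier).

    F is the orthonormal 2D FFT and M a binary mask, so the gradient's
    Lipschitz constant is at most 1 and a unit step size is used.
    """
    def forward(x):
        return mask * np.fft.fft2(x, norm="ortho")

    def adjoint(k):
        return np.fft.ifft2(mask * k, norm="ortho")

    x = adjoint(y)                                  # zero-filled initialisation
    z, t = x.copy(), 1.0
    for _ in range(n_iter):
        grad = adjoint(forward(z) - y)              # gradient of the data term at z
        x_new = soft_threshold(z - grad, lam)       # proximal step (step size 1)
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        z = x_new + ((t - 1.0) / t_new) * (x_new - x)   # momentum
        x, t = x_new, t_new
    return x


# Hypothetical toy problem: 40 spikes in a 64x64 image, ~30% random k-space samples
rng = np.random.default_rng(0)
n = 64
x_true = np.zeros((n, n))
support = rng.choice(n * n, size=40, replace=False)
x_true.flat[support] = rng.uniform(1.0, 2.0, size=40)
sample_mask = rng.random((n, n)) < 0.3
y_toy = sample_mask * np.fft.fft2(x_true, norm="ortho")


def rel_err(x):
    return np.linalg.norm(np.abs(x) - x_true) / np.linalg.norm(x_true)


x_zf = np.fft.ifft2(y_toy, norm="ortho")        # zero-filled reconstruction
x_hat = fista(y_toy, sample_mask, lam=0.02, n_iter=500)
print(f"zero-filled: {rel_err(x_zf):.3f}, FISTA: {rel_err(x_hat):.3f}")
```

The toy example uses random rather than regular sampling because it makes the aliasing incoherent, which is what lets thresholding separate the sparse signal from the aliasing.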

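Threshold estimation using SURE

Instead of tuning the regularisation parameter by hand, AutoWeightedSparseThreshold estimates the threshold from statistics of the wavelet detail coefficients, using Stein's Unbiased Risk Estimate (SURE).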
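
AutoWeightedSparseThreshold's internal implementation is not shown in this tutorial. As a hedged illustration of the underlying idea, here is the classical SureShrink rule in NumPy: estimate the noise level from the median absolute coefficient, then pick the soft threshold that minimises SURE. It uses one threshold for the whole set of coefficients, which is how we read `thresh_range="global"` (an assumption). The test signal is hypothetical and the coefficients are real-valued; for complex coefficients one would threshold the magnitudes.

```python
import numpy as np


def mad_sigma(coeffs):
    """Robust noise std estimate: median absolute value / 0.6745 (Gaussian noise)."""
    return np.median(np.abs(coeffs)) / 0.6745


def soft_threshold_real(x, t):
    """Real-valued soft thresholding."""
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)


def sure_soft(x, t, sigma):
    """SURE for soft thresholding x at level t, with noise std sigma."""
    return (x.size * sigma**2
            - 2.0 * sigma**2 * np.count_nonzero(np.abs(x) <= t)
            + np.sum(np.minimum(np.abs(x), t) ** 2))


def sure_threshold(x, sigma):
    """Threshold minimising SURE; the minimum is attained at 0 or at some |x_i|."""
    candidates = np.concatenate(([0.0], np.sort(np.abs(x).ravel())))
    risks = [sure_soft(x, t, sigma) for t in candidates]   # O(n^2): fine for a sketch
    return candidates[int(np.argmin(risks))]


# Hypothetical test signal: 100 large coefficients among 2000, unit Gaussian noise
rng = np.random.default_rng(0)
theta = np.zeros(2000)
theta[:100] = rng.uniform(3.0, 6.0, size=100)
y_noisy = theta + rng.standard_normal(theta.size)

sigma_hat = mad_sigma(y_noisy)
t_sure = sure_threshold(y_noisy, sigma_hat)
print(f"sigma_hat={sigma_hat:.3f}, threshold={t_sure:.3f}")
```

Between two consecutive candidates the SURE curve only increases with the threshold, which is why searching over 0 and the coefficient magnitudes is enough.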

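# Debugging helper (not used in the rest of this example): prints the
# distinct weight values and returns them unchanged.
def static_weight(w, idx):
    print(np.unique(w))
    return w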

# Setup the operators
linear_op = WaveletN(wavelet_name="sym8", nb_scale=3, padding_mode="periodization")
# Apply the transform once so that linear_op.coeffs_shape is available
coeffs = linear_op.op(image_rec0)
print(linear_op.coeffs_shape)

# Here we do not set the regularisation weights manually: they are estimated
# from statistics of the wavelet detail coefficients.

regularizer_op = AutoWeightedSparseThreshold(
    linear_op.coeffs_shape, linear=Identity(),
    update_period=0, # the weight is updated only once.
    sigma_range="global",
    thresh_range="global",
    threshold_estimation="sure",
    thresh_type="soft",
)
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    verbose=1,
)
# Start Reconstruction
x_final, costs, metrics2 = reconstructor.reconstruct(
    kspace_data=kspace_data,
    optimization_alg='pogm',
    num_iterations=100,
    metric_call_period=1,
    cost_op_kwargs={"cost_interval":None},
    metrics = {
        "snr":{
            "metric": snr,
            "mapping": {"x_new":"test"},
            "cst_kwargs": {"ref": image},
            "early_stopping":False,
        },
        "ssim":{
            "metric": ssim,
            "mapping": {"x_new":"test"},
            "cst_kwargs": {"ref": image},
            "early_stopping": False,
        },
        "cost_grad":{
            "metric": lambda x: reconstructor.gradient_op.cost(linear_op.op(x)),
            "mapping": {"x_new":"x"},
            "cst_kwargs": {},
            "early_stopping": False,
        },
        "cost_prox":{
            "metric": lambda x: reconstructor.prox_op.cost(linear_op.op(x)),
            "mapping": {"x_new":"x"},
            "cst_kwargs": {},
            "early_stopping": False,
        }
    }
)
image_rec2 = np.abs(x_final)
# Calculate SSIM and SNR against the ground truth
recon_ssim2 = ssim(image_rec2, image)
recon_snr2 = snr(image_rec2, image)

print('The Reconstruction SSIM is : ' + str(recon_ssim2))
print('The Reconstruction SNR is : ' + str(recon_snr2))

plt.subplot(121)
plt.plot(metrics["snr"]["time"], metrics["snr"]["values"], label="pogm classic")
plt.plot(metrics2["snr"]["time"], metrics2["snr"]["values"], label="pogm sure global")
plt.ylabel("snr")
plt.xlabel("time")
plt.legend()
plt.subplot(122)
plt.plot(metrics["ssim"]["time"], metrics["ssim"]["values"])
plt.plot(metrics2["ssim"]["time"], metrics2["ssim"]["values"])
plt.ylabel("ssim")
plt.xlabel("time")
plt.figure()
plt.subplot(121)
plt.plot(metrics["snr"]["index"], metrics["snr"]["values"], label="pogm classic")
plt.plot(metrics2["snr"]["index"], metrics2["snr"]["values"], label="pogm sure global")
plt.ylabel("snr")
plt.xlabel("iteration")
plt.legend()
plt.subplot(122)
plt.plot(metrics["ssim"]["index"], metrics["ssim"]["values"])
plt.plot(metrics2["ssim"]["index"], metrics2["ssim"]["values"])
plt.ylabel("ssim")
plt.xlabel("iteration")
plt.show()

Out:

[(64, 64), (64, 64), (64, 64), (64, 64), (128, 128), (128, 128), (128, 128), (256, 256), (256, 256), (256, 256)]
WARNING: Making input data immutable.
Lipschitz constant is 1.1000001311302186
The lipschitz constraint is satisfied
 - mu:  [0. 0. 0. ... 0. 0. 0.]
 - lipschitz constant:  1.1000001311302186
 - data:  (512, 512)
 - wavelet:  <mri.operators.linear.wavelet.WaveletN object at 0x7f3a4e93b910> - 3
 - max iterations:  100
 - image variable shape:  (1, 512, 512)
----------------------------------------
Starting optimization...

100%|##########| 100/100 [00:17<00:00,  5.60it/s]
 - final iteration number:  100
 - final log10 cost value:  6.0
 - converged:  False
Done.
Execution time:  17.867400695104152  seconds
----------------------------------------
The Reconstruction SSIM is : 0.8140168579836002
The Reconstruction SNR is : 14.087662909035688
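
The first line of the output above is `linear_op.coeffs_shape`: one approximation band plus three detail bands per scale, coarsest first. Assuming the `sym8` transform with `padding_mode="periodization"` behaves like PyWavelets' `wavedec2` with `mode="periodization"` (an assumption about pysap's backend), the same list can be reproduced directly; the zero image is only a placeholder of the right size.

```python
import numpy as np
import pywt

# 3-level sym8 decomposition of a 512x512 image with periodization padding
toy_coeffs = pywt.wavedec2(np.zeros((512, 512)), "sym8", mode="periodization", level=3)
# toy_coeffs = [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)]
shapes = [toy_coeffs[0].shape] + [band.shape for level in toy_coeffs[1:] for band in level]
print(shapes)
```

With periodization each level halves the size exactly, so the coefficients add up to 512 × 512 values with no redundancy.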

Qualitative results

def my_imshow(ax, img, title):
    ax.imshow(img, cmap="gray")
    ax.set_title(title)
    ax.axis("off")


fig, axs = plt.subplots(2, 2)

my_imshow(axs[0, 0], image, "Ground Truth")
my_imshow(axs[0, 1], abs(image_rec0), f"Zero Order \n SSIM={base_ssim:.4f}")
my_imshow(axs[1, 0], abs(image_rec), f"POGM Classic \n SSIM={recon_ssim:.4f}")
my_imshow(axs[1, 1], abs(image_rec2), f"POGM SURE \n SSIM={recon_ssim2:.4f}")

fig.tight_layout()
plt.show()

[Figure: Ground Truth; Zero Order, SSIM=0.7229; POGM Classic, SSIM=0.8597; POGM SURE, SSIM=0.8140]

Total running time of the script: ( 0 minutes 33.467 seconds)

Gallery generated by Sphinx-Gallery