Deep Learning for Iterative Spectral CT Reconstruction: Replacing Statistical Iterations with an Attention-Based U-Net
RMS as a function of time
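A plot of RMS error versus time, which includes:
- RMS error of the iterative bone reconstruction against time
- RMS error of the iterative water reconstruction against time
- RMS error of the model's bone prediction against time
- RMS error of the model's water prediction against time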
import pickle
import matplotlib.pyplot as plt
import numpy as np
# Number of samples along each axis of the reconstructed volume; the flat
# reconstructions below are reshaped to objectSize x objectSize x objectSize.
objectSize = 32
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('phantom_results_2.pkl', 'rb') as f:
    data = pickle.load(f)
# Pull every result array out of the pickled dict. The precomputed RMS
# entries (rms_itt_*, rms_model_*) are not used below; the RMS curves are
# recomputed from the images instead.
(reconstructions, ys, phantoms, model_images, times, times_model,
 rms_itt_bone, rms_itt_water, rms_model_bone, rms_model_water) = [
    data[key]
    for key in ('reconstructions', 'ys', 'phantoms', 'output_images',
                'times', 'times_images',
                'rms_reconstructions_bone', 'rms_reconstructions_water',
                'rms_models_bone', 'rms_models_water')
]
def rms_error(image1, image2):
    """Return the root-mean-square difference between two images.

    Both inputs are converted to float64 arrays before subtracting. This lets
    array-likes (e.g. nested lists) be passed directly, and it stops
    unsigned-integer images from wrapping around (uint8 ``0 - 1`` would
    otherwise be 255).

    Args:
        image1: First image (any numeric array-like).
        image2: Second image; must have the same shape as ``image1``.

    Returns:
        The RMS error as a NumPy float64 scalar.

    Raises:
        ValueError: If the two images do not have the same shape.
    """
    a = np.asarray(image1, dtype=np.float64)
    b = np.asarray(image2, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError("Images must have the same dimensions")
    # sqrt(mean((a - b)^2))
    return np.sqrt(np.mean((a - b) ** 2))
# For every phantom, compute the RMS error against the ground truth of
#   * each iterate of the iterative reconstruction, and
#   * each model output,
# separately for the bone and water material images.
rms_reconstructions_bone = []
rms_reconstructions_water = []
rms_models_bone = []
rms_models_water = []
for reconstruction, output_image, phantom in zip(reconstructions, model_images, phantoms):
    # Ground-truth material images; transposed so their axis order matches
    # the Fortran-order reshape of the reconstruction (presumably -- verify).
    phantom_bone = phantom[0].transpose()
    phantom_water = phantom[1].transpose()
    # The reconstruction holds all iterates as (voxels, nMats, nIterates);
    # unflatten the voxel axis into an objectSize**3 cube.
    _, nMats, nIterates = reconstruction.shape
    images = reconstruction.reshape(
        (objectSize, objectSize, objectSize, nMats, nIterates), order='F')
    # Material 0 is compared with the bone phantom, material 1 with water.
    rms_recon_bone = [rms_error(images[:, :, :, 0, i], phantom_bone)
                      for i in range(nIterates)]
    rms_recon_water = [rms_error(images[:, :, :, 1, i], phantom_water)
                       for i in range(nIterates)]
    rms_reconstructions_bone.append(rms_recon_bone)
    rms_reconstructions_water.append(rms_recon_water)
    # output_image holds the model's (bone, water) predictions, one pair per
    # entry of the matching times_model list. Distinct names are used here
    # so the rms_model_bone/rms_model_water values loaded from the pickle
    # are not overwritten.
    model_errors_bone = []
    model_errors_water = []
    for pred_bone, pred_water in output_image:
        model_errors_bone.append(rms_error(pred_bone, phantom_bone))
        model_errors_water.append(rms_error(pred_water, phantom_water))
    rms_models_bone.append(model_errors_bone)
    rms_models_water.append(model_errors_water)
# Keep only the model outputs produced within the iterative method's total
# runtime, so both curves span the same time range. max(times[i]) is
# computed once per phantom instead of once per element.
for i, model_times in enumerate(times_model):
    t_max = max(times[i])
    times_model[i] = [t for t in model_times if t <= t_max]
# Truncate the model RMS curves to the same length. This assumes each
# times_model[i] is sorted ascending, so the kept times form a prefix.
for i in range(len(rms_models_bone)):
    n_kept = len(times_model[i])
    rms_models_bone[i] = rms_models_bone[i][:n_kept]
    rms_models_water[i] = rms_models_water[i][:n_kept]
# Plot the RMS errors of the iterative reconstruction and of the model against
# time (s), one subplot per phantom. At most three phantoms are shown. The
# same count drives both the loop and the subplot grid, so the layout never
# has empty columns and never indexes past the available phantoms.
n_plots = min(3, len(reconstructions))
plt.figure(figsize=(15, 5))
for i in range(n_plots):
    time_model = np.array(times_model[i])
    time_recon = np.array(times[i])
    plt.subplot(1, n_plots, i + 1)
    plt.plot(time_recon, rms_reconstructions_bone[i], label='RMS Mechlem Bone', marker='o')
    plt.plot(time_recon, rms_reconstructions_water[i], label='RMS Mechlem Water', marker='o')
    plt.plot(time_model, rms_models_bone[i], label='RMS Model Bone', marker='x')
    plt.plot(time_model, rms_models_water[i], label='RMS Model Water', marker='x')
    plt.xlabel('Time (s)')
    plt.ylabel('RMS Error')
    plt.ylim([0, 0.7])
    plt.title(f'Phantom {i + 1}')
    plt.legend()
plt.tight_layout()
plt.show()

# import interactive slider for displaying images
from ipywidgets import interact, IntSlider
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('phantom_results_3.pkl', 'rb') as f:
    data = pickle.load(f)
# Rebind every result name to the contents of the second results file.
(reconstructions, ys, phantoms, model_images, times, times_model,
 rms_itt_bone, rms_itt_water, rms_model_bone, rms_model_water) = [
    data[key]
    for key in ('reconstructions', 'ys', 'phantoms', 'output_images',
                'times', 'times_images',
                'rms_reconstructions_bone', 'rms_reconstructions_water',
                'rms_models_bone', 'rms_models_water')
]
# Interactive viewer: for the selected phantom, show bone and water slices of
# the ground truth, the last iterate of the iterative reconstruction, and the
# model output. The phantom index is chosen with a slider.
def display_reconstruction_images(index):
    """Plot bone/water slices of phantom ``index`` from three sources.

    Rows: ground-truth phantom, last iterate of the iterative reconstruction,
    and the model output ``model_images[index][2]``. Columns: bone, water.
    Every panel shows slice ``objectSize // 2`` along the last spatial axis.

    Args:
        index: Position of the phantom in the module-level ``phantoms``,
            ``reconstructions`` and ``model_images`` lists.
    """
    mid = objectSize // 2
    # Transposed the same way as in the RMS computation so the axes line up
    # with the Fortran-order reshape below (presumably -- verify).
    phantom_bone = phantoms[index][0].transpose()
    phantom_water = phantoms[index][1].transpose()
    reconstruction = reconstructions[index]
    _, nMats, nIterates = reconstruction.shape
    images = reconstruction.reshape(
        (objectSize, objectSize, objectSize, nMats, nIterates), order='F')
    # NOTE(review): [2] selects the third (bone, water) pair among this
    # phantom's model outputs -- confirm this is the intended one.
    model_bone = model_images[index][2][0]
    model_water = model_images[index][2][1]
    # (row title prefix, bone volume, water volume); iterate -1 is the last one.
    rows = [
        ('Phantom', phantom_bone, phantom_water),
        ('Iterative', images[:, :, :, 0, -1], images[:, :, :, 1, -1]),
        ('Model', model_bone, model_water),
    ]
    fig, ax = plt.subplots(3, 2, figsize=(10, 7.5))
    for r, (source, bone, water) in enumerate(rows):
        for c, (material, volume) in enumerate((('Bone', bone), ('Water', water))):
            ax[r, c].imshow(volume[:, :, mid], cmap='gray')
            ax[r, c].set_title(f'{source} {material}')
            ax[r, c].axis('off')
    plt.tight_layout()
    plt.show()
# Browse the phantoms interactively; the slider starts at phantom 16.
phantom_slider = IntSlider(
    min=0,
    max=len(reconstructions) - 1,
    step=1,
    value=16,
    description='Phantom Index',
)
interact(display_reconstruction_images, index=phantom_slider)
Loading...
# Display phantom 16 as a 2x3 grid: one row per material (water on top, bone
# below) and one column per source (ground truth, iterative method, AttU-Net).
# Every panel shows slice objectSize // 2 along the last spatial axis.
phantom_index = 16
phantom_bone = phantoms[phantom_index][0].transpose()   # ground-truth bone
phantom_water = phantoms[phantom_index][1].transpose()  # ground-truth water
reconstruction = reconstructions[phantom_index]
_, nMats, nIterates = reconstruction.shape
images = reconstruction.reshape((objectSize, objectSize, objectSize, nMats, nIterates), order='F')
# Last iterate of the iterative method; material 0 is bone, 1 is water.
water_reconstruction = images[:, :, :, 1, -1]
bone_reconstruction = images[:, :, :, 0, -1]
model_bone = model_images[phantom_index][2][0][:, :, objectSize // 2]
model_water = model_images[phantom_index][2][1][:, :, objectSize // 2]

mid = objectSize // 2
panels = [
    ('Water', [phantom_water[:, :, mid], water_reconstruction[:, :, mid], model_water]),
    ('Bone', [phantom_bone[:, :, mid], bone_reconstruction[:, :, mid], model_bone]),
]
fig, ax = plt.subplots(2, 3, figsize=(10, 5))
for row, (material, row_images) in enumerate(panels):
    for col, img in enumerate(row_images):
        ax[row, col].imshow(img, cmap='gray')
    # The first column keeps its frame (ticks removed) so the material label
    # stays visible; the other columns hide their axes entirely.
    ax[row, 0].set_ylabel(material, fontsize=12)
    ax[row, 0].set_xticks([])
    ax[row, 0].set_yticks([])
    ax[row, 1].axis('off')
    ax[row, 2].axis('off')
for col, title in enumerate(('GT', 'iterative method', 'AttU-Net')):
    ax[0, col].set_title(title)
plt.tight_layout()
plt.show()
