
Deep Learning for Iterative Spectral CT Reconstruction: Replacing Statistical Iterations with an Attention-Based U-Net

The AttU-Net model

An implementation of a 3D U-Net with attention gates added to the skip connections

import h5py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
class conv_block(nn.Module):
    # Two stacked (Conv3d -> BatchNorm3d -> ReLU) layers
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm3d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv3d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm3d(ch_out),
            nn.ReLU(inplace=True)
        )


    def forward(self,x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    # Upsample by a factor of 2, then Conv3d -> BatchNorm3d -> ReLU
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm3d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x

class Attention_block(nn.Module):
    # Additive attention gate: the decoder's gating signal g reweights the encoder's skip features x
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv3d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm3d(F_int)
            )
        
        self.W_x = nn.Sequential(
            nn.Conv3d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm3d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv3d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm3d(1),
            nn.Sigmoid()
        )
        
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, g, x, return_attention=False):
        # Fuse the gating signal and the skip features, then squeeze to a
        # single-channel attention map psi in [0, 1]
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        if return_attention:
            return x * psi, psi
        return x * psi
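As a quick sanity check, here is a minimal sketch of the gate in isolation (the channel counts and the 8^3 spatial size are illustrative assumptions, not values from the study):

# Illustrative shape check for the attention gate; sizes are assumptions
gate = Attention_block(F_g=64, F_l=64, F_int=32)
g = torch.randn(1, 64, 8, 8, 8)   # gating signal from the decoder
x = torch.randn(1, 64, 8, 8, 8)   # skip features from the encoder
gated, psi = gate(g, x, return_attention=True)
print(gated.shape)  # torch.Size([1, 64, 8, 8, 8]) - reweighted skip features
print(psi.shape)    # torch.Size([1, 1, 8, 8, 8])  - per-voxel attention weights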

class AttU_Net(nn.Module):
    def __init__(self,img_ch=10,output_ch=2):
        super(AttU_Net,self).__init__()
        
        self.Maxpool = nn.MaxPool3d(kernel_size=2,stride=2)

        self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
        self.Conv2 = conv_block(ch_in=64,ch_out=128)
        self.Conv3 = conv_block(ch_in=128,ch_out=256)
        self.Conv4 = conv_block(ch_in=256,ch_out=512)
        self.Conv5 = conv_block(ch_in=512,ch_out=1024)

        self.Up5 = up_conv(ch_in=1024,ch_out=512)
        self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512,ch_out=256)
        self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        
        self.Up3 = up_conv(ch_in=256,ch_out=128)
        self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        
        self.Up2 = up_conv(ch_in=128,ch_out=64)
        self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv3d(64,output_ch,kernel_size=1,stride=1,padding=0)


    def forward(self,x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5,x=x4)
        d5 = torch.cat((x4,d5),dim=1)        
        d5 = self.Up_conv5(d5)
        
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4,x=x3)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3,x=x2)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2,x=x1)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        return torch.sigmoid(d1)  # sigmoid constrains each output channel to [0, 1] (targets are assumed normalized)
    
    def getAttenuationMap(self, x):
        # Same forward pass as above, but also collects the attention map
        # produced by each gate (e.g. for visualization)
        attention_maps = {}

        # Encoding path
        x1 = self.Conv1(x)
        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # Decoding path with attention maps
        d5 = self.Up5(x5)
        x4, att5 = self.Att5(g=d5, x=x4, return_attention=True)
        attention_maps['att5'] = att5
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3, att4 = self.Att4(g=d4, x=x3, return_attention=True)
        attention_maps['att4'] = att4
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2, att3 = self.Att3(g=d3, x=x2, return_attention=True)
        attention_maps['att3'] = att3
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1, att2 = self.Att2(g=d2, x=x1, return_attention=True)
        attention_maps['att2'] = att2
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        output = torch.sigmoid(d1)

        return output, attention_maps
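To verify the plumbing, a small sketch of a forward pass (the batch size of 1 and the 32^3 patch are assumptions; the spatial dimensions only need to be divisible by 16, since the encoder downsamples four times):

# Illustrative forward pass; the 32^3 patch size is an assumption
model = AttU_Net(img_ch=10, output_ch=2).eval()
with torch.no_grad():
    dummy = torch.randn(1, 10, 32, 32, 32)  # (batch, channels, D, H, W)
    recon = model(dummy)
    print(recon.shape)  # torch.Size([1, 2, 32, 32, 32])
    recon, att = model.getAttenuationMap(dummy)
    print(att['att2'].shape)  # torch.Size([1, 1, 32, 32, 32]) - finest-scale gate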
        

Custom dataloader

A custom Dataset class loads each sample from the HDF5 file individually, so the full dataset never has to fit in memory

class HDF5Dataset(Dataset):
    def __init__(self, h5_path, transform=None, fraction=1.0):
        self.h5_path = h5_path
        self.transform = transform
        # Read only the dataset length up front; samples are loaded lazily in __getitem__
        with h5py.File(self.h5_path, 'r') as f:
            self.length = round(f['inputs'].shape[0] * fraction)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Re-open the file on each access so every DataLoader worker gets
        # its own HDF5 handle (handles are not safe to share across processes)
        with h5py.File(self.h5_path, 'r') as f:
            x = f['inputs'][index]
            y = f['outputs'][index]

        x = torch.tensor(x, dtype=torch.float32)
        y = torch.tensor(y, dtype=torch.float32)

        if self.transform:
            x = self.transform(x)
            y = self.transform(y)

        return x, y
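For reference, a minimal sketch of wrapping this dataset in a DataLoader (the file name train.h5, batch size, and worker count are placeholder assumptions):

# Hypothetical usage; 'train.h5' is a placeholder path containing
# 'inputs' and 'outputs' datasets of equal length
dataset = HDF5Dataset('train.h5', fraction=1.0)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
x, y = next(iter(dataloader))
print(x.shape, y.shape)  # e.g. torch.Size([4, 10, D, H, W]) and torch.Size([4, 2, D, H, W])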

Training the network

def train_network(network, dataloader, device, lr=0.001, epochs=20):

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=lr)

    network.train()  # ensure BatchNorm layers are in training mode
    for epoch in range(epochs):
        for i, (input_batch, output_batch) in enumerate(dataloader):
            input_batch = input_batch.to(device)
            output_batch = output_batch.to(device)

            optimizer.zero_grad()
            output = network(input_batch)
            loss = criterion(output, output_batch)
            loss.backward()
            optimizer.step()

            # Periodically log progress and checkpoint the weights
            if i % 1000 == 0:
                message = f'Epoch {epoch + 1}/{epochs}, Batch {i + 1}/{len(dataloader)}, Loss: {loss.item()}'
                with open("log.txt", "a") as log:
                    log.write(message + '\n')
                print(message)
                torch.save(network.state_dict(), 'network_phantom_epoch.pth')

        message = f'Epoch {epoch + 1}/{epochs}, Basic Data Loss: {loss.item()}'
        with open("log.txt", "a") as log:
            log.write(message + '\n')
        print(message)
        torch.save(network.state_dict(), 'network_phantom.pth')

    return loss.item()  # loss of the final batch
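Putting the pieces together, one possible end-to-end run (the path and hyperparameters are placeholders, not the settings used for the reported results):

# Hypothetical end-to-end training run
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
network = AttU_Net(img_ch=10, output_ch=2).to(device)
dataset = HDF5Dataset('train.h5')  # placeholder path
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
final_loss = train_network(network, dataloader, device, lr=0.001, epochs=20)
print(f'Final batch loss: {final_loss}')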