In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.rewards import anatomy_based_rwd
from armscan_env.util.visualizations import show_slices

config = get_config()

# Simple Clustering and Linear Sweep Search

Loading the dataset:

In [None]:
volume = sitk.ReadImage(config.get_single_labelmap_path(1))
volume_img = sitk.GetArrayFromImage(volume)
print(f"{volume_img.shape=}")

size = np.array(volume.GetSize()) * np.array(volume.GetSpacing())
print(f"{size=} mm")
transversal_extent = (0, size[0], 0, size[2])
longitudinal_extent = (0, size[1], 0, size[2])
frontal_extent = (0, size[0], size[1], 0)

In this notebook, we explain how to search for the carpal tunnel along one axis using a simple clustering algorithm. The first step is to visualize the data to understand the anatomy of the hand at the level of the carpal tunnel. We then use a simple clustering algorithm to identify the number of features present in each image. Based on the number of features and their relative positions, we can identify the carpal tunnel from its anatomical description.

Since we do not change the orientation of the slices along the hand, we are bound to a suboptimal view along the axis on which the images were stacked. This axis is not exactly transversal to the carpal tunnel, so our anatomical description is relative to this suboptimal orientation. However, we can still show that the anatomical description of the region of interest is enough to optimize the navigation.

The following images show slices near the carpal tunnel area.

In [None]:
show_slices(data=volume_img, start=690, end=20, lap=1, extent=transversal_extent)
plt.show()

The label values are used to identify the tissue clusters and to reason about the anatomy seen in the image.

In [None]:
tissues = {
    "bones": 1,
    "tendons": 2,
    "ulnar": 3,
}

The function `cluster_iter` identifies the clusters of tissues in the image. It iterates over the tissues in the dictionary and finds the clusters of each tissue, i.e. groups of neighboring pixels with the same label value, using a centrosymmetric structuring element. The algorithm is based on the `label` function from the `scipy.ndimage` package. The function returns, for each tissue, its clusters together with their centers; the clusters of a tissue are accessed as attributes, e.g. `clusters_679.bones`.
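A minimal sketch of the idea behind `cluster_iter`, assuming that the clusters of a tissue are the connected components of its binary mask. The helper `find_clusters` and the `Cluster` tuple are hypothetical and are not the `armscan_env.clustering` API; the sketch only relies on `scipy.ndimage.label` and `scipy.ndimage.center_of_mass`.

```python
from typing import NamedTuple

import numpy as np
from scipy import ndimage


class Cluster(NamedTuple):
    """Hypothetical container for one cluster (not the armscan_env type)."""

    pixels: np.ndarray  # (n_pixels, 2) array of (row, col) indices
    center: np.ndarray  # (row, col) center of mass


def find_clusters(tissues: dict[str, int], slice_: np.ndarray) -> dict[str, list[Cluster]]:
    """Group neighboring pixels with the same label value into clusters."""
    # 3x3 cross (4-connectivity); ndimage.label requires a centrosymmetric structure
    structure = ndimage.generate_binary_structure(2, 1)
    result: dict[str, list[Cluster]] = {}
    for name, value in tissues.items():
        mask = slice_ == value
        labeled, n_clusters = ndimage.label(mask, structure=structure)
        clusters = []
        for k in range(1, n_clusters + 1):
            component = labeled == k
            center = np.array(ndimage.center_of_mass(component.astype(float)))
            clusters.append(Cluster(pixels=np.argwhere(component), center=center))
        result[name] = clusters
    return result


# Toy slice: two separate bone blobs (label 1) and one tendon blob (label 2)
toy = np.zeros((6, 6), dtype=int)
toy[0:2, 0:2] = 1
toy[4:6, 4:6] = 1
toy[0:2, 4:6] = 2
toy_clusters = find_clusters({"bones": 1, "tendons": 2, "ulnar": 3}, toy)
print({name: len(c) for name, c in toy_clusters.items()})  # {'bones': 2, 'tendons': 1, 'ulnar': 0}
print(toy_clusters["bones"][0].center)  # [0.5 0.5]
```

With 4-connectivity, two blobs that touch only diagonally count as separate clusters; `generate_binary_structure(2, 2)` would merge them instead.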

In [None]:
from armscan_env.clustering import cluster_iter
from armscan_env.util.visualizations import show_clusters

clusters_679 = cluster_iter(tissues, volume_img[:, 679, :].T)
ax = show_clusters(clusters_679, volume_img[:, 679, :].T, extent=transversal_extent)
ax.set_xlabel("X [mm]")
ax.set_ylabel("Z [mm]")
plt.show()

Visualizing slices of the hand at different levels makes clear why reasoning about the anatomy of a single slice is enough to identify the region of interest.
The function `anatomy_based_rwd` computes a score for each image. This provides an observable reward, which can be used to solve the navigation problem with classical search methods as well as with RL algorithms. The score is based on the number of clusters recognized for each tissue, which should match the corresponding entry of the `n_landmarks` parameter. A tissue that is missing entirely is penalized more heavily, because it means that the navigation is far off. Moreover, the score takes the position of the landmarks into account: in particular, it checks whether the ulnar artery lies underneath the tendon clusters. Finally, the score is normalized to sum to one.

We tuned the score function to our suboptimal region of interest: it returns zero loss for the slice showing the described anatomical configuration.
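The count-based part of such a score can be sketched as below. This is not `anatomy_based_rwd`: the function `count_loss`, the cap of 0.5 on a miscount, and the equal per-tissue weights (which sum to one, so the loss lies in [0, 1]) are all assumptions chosen to mirror the description above, with a missing tissue always costing more than a wrong count.

```python
import numpy as np


def count_loss(n_found: tuple[int, ...], n_landmarks: tuple[int, ...]) -> float:
    """Hypothetical loss: 0 when every tissue shows its expected number of clusters."""
    if len(n_found) != len(n_landmarks):
        raise ValueError("n_found and n_landmarks must have the same length")
    terms = []
    for found, expected in zip(n_found, n_landmarks):
        if found == 0:
            # Tissue missing entirely: the navigation is far off, maximum penalty
            terms.append(1.0)
        else:
            # Relative count error, capped below the missing-tissue penalty
            terms.append(min(abs(found - expected) / expected, 0.5))
    # Equal weights summing to one keep the total loss in [0, 1]
    weights = np.full(len(terms), 1.0 / len(terms))
    return float(np.dot(weights, terms))


# Expected landmarks as in the notebook: 7 bones, 5 tendons, 1 ulnar artery
print(count_loss((7, 5, 1), (7, 5, 1)))  # 0.0
print(round(count_loss((6, 5, 1), (7, 5, 1)), 3))  # 0.048
print(round(count_loss((7, 5, 0), (7, 5, 1)), 3))  # 0.333
```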

In [None]:
# Create a figure and a gridspec with two rows and three columns
fig = plt.figure(constrained_layout=True, figsize=(12, 6))
gs = fig.add_gridspec(2, 3)

# Slices to compare; the overview panel marks their positions along the hand
slice_indices = [418, 601, 679, 739]

# Left column: overview of the hand with one dashed line per selected slice
ax1 = fig.add_subplot(gs[:, 0])
for idx in slice_indices:
    ax1.axhline(y=idx, color="red", linestyle="--", label=f"Horizontal Line at Z-index = {idx}")
ax1.imshow(volume_img[35, :, :], label="Hand")
ax1.set_xlabel("x: pixels")
ax1.set_ylabel("z: pixels")
ax1.legend()

# Right 2x2 block: clusters and reward of each selected slice
slice_axes = [fig.add_subplot(gs[row, col]) for row in (0, 1) for col in (1, 2)]
for ax, idx in zip(slice_axes, slice_indices):
    slice_clusters = cluster_iter(tissues, volume_img[:, idx, :].T)
    show_clusters(slice_clusters, volume_img[:, idx, :].T, ax, extent=transversal_extent)
    reward = anatomy_based_rwd(slice_clusters, n_landmarks=(7, 5, 1))
    ax.text(
        0.95,
        0.1,
        f"Reward: {reward:.2f}",
        horizontalalignment="right",
        verticalalignment="bottom",
        transform=ax.transAxes,
        bbox=dict(facecolor="white", alpha=0.5, edgecolor="black", boxstyle="round,pad=1"),
    )
    ax.set_title(f"Slice {idx}")
    ax.axis("off")

plt.show()

The clustering algorithm gives us information not only about the number of tissue clusters, but also about their position. Hence, it is possible to reason about the orientation of the image and the relation of the tissues to one another.

In [None]:
clusters_679 = cluster_iter(tissues, volume_img[:, 679, :].T)

# Centers of all bone and tendon clusters in slice 679
bones_centers = [cluster.center for cluster in clusters_679.bones]
tendons_centers = [cluster.center for cluster in clusters_679.tendons]

# Mean center of each tissue group, plus the center of the single ulnar cluster
bones_center = np.mean(bones_centers, axis=0)
print("bones_center: ", bones_center)
tendons_center = np.mean(tendons_centers, axis=0)
print("tendons_center: ", tendons_center)
ulnar_center = clusters_679.ulnar[0].center
print("ulnar_center: ", ulnar_center)
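With such centers, the position term of the score can be sketched as a comparison along one image axis. The helper `ulnar_below_tendons` is hypothetical, and so is the convention that a larger index along `vertical_axis` means lower in the displayed image; which axis is "vertical" depends on how the slice is transposed and displayed.

```python
import numpy as np


def ulnar_below_tendons(
    ulnar_center: np.ndarray,
    tendon_centers: list[np.ndarray],
    vertical_axis: int = 0,
) -> bool:
    """Hypothetical check: does the ulnar artery lie below the mean tendon center?

    Assumes image coordinates, where a larger index along `vertical_axis`
    means lower in the displayed image.
    """
    tendons_mean = np.mean(tendon_centers, axis=0)
    return bool(ulnar_center[vertical_axis] > tendons_mean[vertical_axis])


toy_tendons = [np.array([10.0, 20.0]), np.array([12.0, 30.0])]
print(ulnar_below_tendons(np.array([25.0, 22.0]), toy_tendons))  # True
print(ulnar_below_tendons(np.array([5.0, 22.0]), toy_tendons))  # False
```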

In [None]:
anatomy_based_rwd(clusters_679, n_landmarks=(7, 5, 1))

A linear sweep search along the axis of the hand identifies the optimal region, i.e. the slices that return zero loss. The plot also shows that the loss approaches zero as we get closer to the optimal region.
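The sweep is an exhaustive search over slice indices. A minimal sketch, assuming only a callable that maps a slice index to a loss, also groups consecutive zero-loss indices into regions. Here `toy_loss` is a hypothetical stand-in for running `cluster_iter` followed by `anatomy_based_rwd` on a slice, and `linear_sweep` is not part of `armscan_env`.

```python
from collections.abc import Callable

import numpy as np


def linear_sweep(
    loss_fn: Callable[[int], float], n_slices: int
) -> tuple[np.ndarray, list[tuple[int, int]]]:
    """Evaluate every slice; return all losses and the zero-loss regions.

    Each region is an inclusive (start, end) range of consecutive slice indices.
    """
    losses = np.array([loss_fn(i) for i in range(n_slices)])
    regions: list[tuple[int, int]] = []
    for idx in np.flatnonzero(losses == 0):
        if regions and idx == regions[-1][1] + 1:
            # Extend the current run of consecutive zero-loss slices
            regions[-1] = (regions[-1][0], int(idx))
        else:
            regions.append((int(idx), int(idx)))
    return losses, regions


# Toy loss: zero on slices 4-6 and 9, growing with the distance to them elsewhere
def toy_loss(i: int) -> float:
    if 4 <= i <= 6 or i == 9:
        return 0.0
    return min(abs(i - 5), abs(i - 9)) / 10


toy_losses, toy_regions = linear_sweep(toy_loss, 12)
print(toy_regions)  # [(4, 6), (9, 9)]
print(int(np.argmin(toy_losses)))  # 4
```

The cost is one loss evaluation per slice, which is affordable along a single axis but motivates smarter search methods once the slice orientation is also optimized.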

In [None]:
sweep_loss = []
zero_loss_clusters = []

for i in range(volume_img.shape[1]):
    clusters = cluster_iter(tissues, volume_img[:, i, :].T)
    loss = anatomy_based_rwd(clusters, n_landmarks=(7, 5, 1))
    sweep_loss.append(loss)
    if loss == 0:
        zero_loss_clusters.append(clusters)
    print(f"Loss for slice {i}: {sweep_loss[i]}")

In [None]:
plt.plot(range(len(sweep_loss)), sweep_loss, marker="o")

plt.xlabel("Slice index")
plt.ylabel("Score")
plt.title("Score along axial slices")

plt.show()

We can visualize the slices that return a zero loss to check whether this approach is valid.

In [None]:
zero_loss_indices = np.where(np.array(sweep_loss) == 0)[0]
print(f"{len(zero_loss_indices)} indices return a zero loss: ", zero_loss_indices)

nrows = 1
ncols = len(zero_loss_indices) // nrows
indices_to_display = nrows * ncols

if indices_to_display > 0:
    # squeeze=False always returns a 2D array of Axes, even for a single subplot
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 10), squeeze=False)
    axes = axes.flatten()
    for i, idx in enumerate(zero_loss_indices[:indices_to_display]):
        show_clusters(
            tissue_clusters=zero_loss_clusters[i],
            slice=volume_img[:, idx, :].T,
            extent=transversal_extent,
            aspect=2,
            ax=axes[i],
        )
        axes[i].set_title(f"Index: {idx}, Loss: {sweep_loss[idx]:.2f}")
        axes[i].axis("off")

    plt.show()

As we can see, the results are promising, but the clustering algorithm does not identify the clusters robustly. Some clusters are not separated because their pixels are connected. Moreover, the clustering algorithm cannot be tuned for an expected cluster size, so outliers cannot be detected.
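Two common remedies for these issues can be sketched with `scipy.ndimage`, as an illustration and not as what the next notebook uses: eroding the binary mask before labeling, which can split blobs joined by thin bridges of pixels, and discarding components smaller than an expected minimum size. The helper `label_with_size_filter` and its `min_size` and `erosion_iterations` parameters are hypothetical.

```python
import numpy as np
from scipy import ndimage


def label_with_size_filter(
    mask: np.ndarray, min_size: int, erosion_iterations: int = 0
) -> tuple[np.ndarray, int]:
    """Label the connected components of `mask`, optionally eroding first, and drop small ones."""
    if erosion_iterations > 0:
        # Erosion removes thin bridges that connect otherwise separate blobs
        mask = ndimage.binary_erosion(mask, iterations=erosion_iterations)
    labeled, n = ndimage.label(mask)
    # Pixel count per label; index 0 is the background
    sizes = np.bincount(labeled.ravel(), minlength=n + 1)
    keep = sizes >= min_size
    keep[0] = False
    # Relabel the surviving components as 1..k; everything else becomes 0
    new_ids = np.zeros(n + 1, dtype=int)
    new_ids[keep] = np.arange(1, int(keep.sum()) + 1)
    return new_ids[labeled], int(keep.sum())


toy_mask = np.zeros((7, 9), dtype=bool)
toy_mask[1:6, 0:3] = True  # 5x3 blob
toy_mask[1:6, 5:8] = True  # 5x3 blob
toy_mask[3, 3:5] = True  # thin bridge joining the two blobs
toy_mask[0, 8] = True  # single-pixel outlier

print(label_with_size_filter(toy_mask, min_size=2)[1])  # 1
print(label_with_size_filter(toy_mask, min_size=2, erosion_iterations=1)[1])  # 2
```

Without erosion the bridge merges the two blobs into one cluster, while the size filter drops the outlier; with one erosion step the blobs separate. Erosion also shrinks every cluster, which shifts cluster sizes and can remove small tissues altogether.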

In [None]:
clusters_668 = cluster_iter(tissues, volume_img[:, 668, :].T)
show_clusters(clusters_668, volume_img[:, 668, :].T, extent=transversal_extent, aspect=2)
plt.axis("off")
plt.show()

In the next notebook, we show the performance of a different clustering algorithm.