In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.clustering import (
    TissueClusters,
    TissueLabel,
    find_DBSCAN_clusters,
)
from armscan_env.config import get_config
from armscan_env.envs.rewards import anatomy_based_rwd
from armscan_env.util.visualizations import show_clusters

# Project configuration object; used below to resolve the labelmap file path.
config = get_config()

# DBSCAN Clustering Search

In [None]:
# Read labelmap number 1 and expose it as a numpy array.
labelmap_path = config.get_single_labelmap_path(1)
volume = sitk.ReadImage(labelmap_path)
volume_img = sitk.GetArrayFromImage(volume)
print(f"{volume_img.shape=}")

# Per-axis size of the volume: image size times spacing (reported in mm below).
voxel_counts = np.array(volume.GetSize())
voxel_spacing = np.array(volume.GetSpacing())
size = voxel_counts * voxel_spacing
print(f"{size=} mm")

# Plot extents for the three views, built from the per-axis sizes.
size_x, size_y, size_z = size[0], size[1], size[2]
transversal_extent = (0, size_x, 0, size_z)
longitudinal_extent = (0, size_y, 0, size_z)
frontal_extent = (0, size_x, size_y, 0)

In [None]:
# Mapping from tissue name to its integer label value.
tissues = dict(bones=1, tendons=2, ulnar=3)

DBSCAN might be a better clustering technique that offers more flexibility. It works by defining clusters as continuous regions of high density: it groups together points that are closely packed (points with many nearby neighbors) and marks as outliers the points that lie alone in low-density regions (whose nearest neighbors are too far away). The DBSCAN algorithm has two parameters: `min_samples` and `eps`. The `min_samples` parameter specifies the minimum number of samples (the point itself included) that must lie within distance `eps` of a point for that point to count as a core point of a cluster, so larger values discard small, sparse groups of points as noise. The `eps` parameter specifies the maximum distance between two samples for one to be considered as in the neighborhood of the other.

In [None]:
# Transposed labelmap slice at index 680, shared by all three tissue searches
# and by the plot below.
slice_680 = volume_img[:, 680, :].T

# Per TissueClusters field: (tissue label, eps, min_samples) used for DBSCAN.
dbscan_settings = {
    "bones": (TissueLabel.BONES, 4.1, 46),
    "tendons": (TissueLabel.TENDONS, 3.0, 18),
    "ulnar": (TissueLabel.ULNAR, 2.0, 10),
}

# Run one DBSCAN search per tissue (bones, tendons, ulnar, in that order) and
# pass the results to TissueClusters as keyword arguments.
clusters_680 = TissueClusters(
    **{
        field: find_DBSCAN_clusters(label, slice_680, eps=eps, min_samples=min_samples)
        for field, (label, eps, min_samples) in dbscan_settings.items()
    },
)

# Despite its name, `fig` is whatever show_clusters returns; the axis labels
# are set directly on that object.
fig = show_clusters(clusters_680, slice_680, extent=transversal_extent)
fig.set_xlabel("X [mm]")
fig.set_ylabel("Z [mm]")
plt.show()

In general, this algorithm offers a better anatomical description, since it allows us to reason about the average dimension of the clusters for each kind of tissue, removing possible outliers caused by segmentation errors.

You can play around with the parameters of DBSCAN to find the best tuning, or you can use the constructor `TissueClusters.from_labelmap_slice`, which iterates through the tissues to find clusters with predetermined `eps` and `min_samples` parameters.

In [None]:
# Sweep over every index along axis 1 of the volume: cluster the slice, score
# the clusters, and keep both so later cells can look them up by slice index.
clusters_list, sweep_loss = [], []

for i in range(volume_img.shape[1]):
    labelmap_slice = volume_img[:, i, :].T
    clusters = TissueClusters.from_labelmap_slice(labelmap_slice)
    clusters_list.append(clusters)
    loss = anatomy_based_rwd(clusters, n_landmarks=(7, 3, 1))
    print(f"Loss for slice {i}: {loss}")
    sweep_loss.append(loss)

In [None]:
# Plot the per-slice values collected in `sweep_loss` on the current axes.
score_ax = plt.gca()
slice_indices = range(len(sweep_loss))
score_ax.plot(slice_indices, sweep_loss, marker="o")
score_ax.set_xlabel("Slice index")
score_ax.set_ylabel("Score")
score_ax.set_title("Score along axial slices")

plt.show()

The plot shows the score of the clustering along the axial slices. The optimal region corresponds to the anatomical description, as we can see in the following images.

In [None]:
# Find the slice with the highest reward and show its clusters next to the
# labelmap slice itself.
# NOTE: `sweep_loss` holds the values returned by `anatomy_based_rwd`, so the
# best slice is the argmax. The printed label previously read "Minimum loss"
# although the value shown is the maximum of the list; `max_loss_idx` keeps
# its existing name so that other cells referring to it still work.
max_loss_idx = np.argmax(sweep_loss)
print(f"Maximum reward: {sweep_loss[max_loss_idx]} at index {max_loss_idx}")
fig, ax = plt.subplots(1, 2, figsize=(14, 7))
# Left panel: detected clusters drawn by show_clusters on the transposed slice.
show_clusters(clusters_list[max_loss_idx], volume_img[:, max_loss_idx, :].T, ax=ax[0], aspect=6)
# Right panel: the labelmap slice on its own.
ax[1].imshow(volume_img[:, max_loss_idx, :], aspect=6, origin="lower")
for a in ax:
    a.axis("off")
plt.show()

In [None]:
# Collect every slice whose score is exactly zero and show all of them in a
# grid with a fixed number of rows.
zero_loss_indices = np.where(np.array(sweep_loss) == 0)[0]
print(f"{len(zero_loss_indices)} indices return a zero loss: ", zero_loss_indices)

nrows = 3
# Ceiling division: the previous floor division silently dropped up to
# `nrows - 1` slices and displayed nothing when fewer than `nrows` slices
# qualified.
ncols = -(-len(zero_loss_indices) // nrows)
indices_to_display = len(zero_loss_indices)

if indices_to_display > 0:
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 10))
    # With nrows=3 `axes` is always an ndarray (1-D when ncols == 1), so
    # flatten() gives a flat list of panels in both cases.
    axes = axes.flatten()
    for i, idx in enumerate(zero_loss_indices):
        axes[i] = show_clusters(
            tissue_clusters=clusters_list[idx],
            slice=volume_img[:, idx, :].T,
            aspect=6,
            ax=axes[i],
        )
        axes[i].set_title(f"Index: {idx}, Loss: {sweep_loss[idx]:.2f}")
        axes[i].axis("off")

    # The grid may have more panels than slices; hide the empty ones.
    for unused_ax in axes[indices_to_display:]:
        unused_ax.axis("off")

    plt.show()

Segmenting the same image with the centro-symmetric clustering and DBSCAN demonstrates the difference in the clustering techniques. The centro-symmetric clustering is more sensitive to the noise in the image, while DBSCAN can be tuned to ignore the noise and focus on the main clusters.

In [None]:
from armscan_env.clustering import cluster_iter

# Cluster the same slice (index 668) with both techniques and draw the results
# side by side: centro-symmetric on the left, DBSCAN on the right.
fig, ax = plt.subplots(1, 2, figsize=(14, 7))
slice_668 = volume_img[:, 668, :].T

# One (title, clustering callable) pair per panel, in left-to-right order.
panels = (
    ("Centrosymmetric Clustering", lambda labelmap: cluster_iter(tissues, labelmap)),
    ("DBSCAN Clustering", TissueClusters.from_labelmap_slice),
)

for panel_ax, (panel_title, cluster_fn) in zip(ax, panels):
    # `clusters_668` is rebound on each pass; after the loop it holds the
    # DBSCAN result.
    clusters_668 = cluster_fn(slice_668)
    show_clusters(clusters_668, slice_668, aspect=6, ax=panel_ax)
    panel_ax.axis("off")
    panel_ax.set_title(panel_title)
plt.show()