Environment, all in one

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.clustering import TissueClusters
from armscan_env.config import get_config
from armscan_env.envs.labelmaps_navigation import (
    ArmscanEnv,
    LabelmapClusteringBasedReward,
    LabelmapEnvTerminationCriterion,
)
from armscan_env.envs.observations import (
    LabelmapSliceAsChannelsObservation,
)
from armscan_env.envs.rewards import anatomy_based_rwd
from armscan_env.envs.state_action import ManipulatorAction
from armscan_env.util.visualizations import show_clusters
from armscan_env.volumes.loading import load_sitk_volumes
from celluloid import Camera
from IPython.core.display import HTML

config = get_config()

We can now put everything together in a single environment. We use the get_volume_slice function to take a 2D slice of the 3D volume, cluster the tissue labels in that slice with TissueClusters.from_labelmap_slice (which runs a DBSCAN clustering), and finally compute the reward from the anatomy of the arm with the anatomy_based_rwd function.

volumes = load_sitk_volumes(normalize=True)
img_array_1 = sitk.GetArrayFromImage(volumes[0])
img_array_2 = sitk.GetArrayFromImage(volumes[1])
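Before wiring everything into an environment, here is a minimal single-slice pass through the pipeline, reusing the calls from the animation cell below (the translation of 140 mm is just one of the values swept there): slice the volume, cluster the tissues, and score the result.

# One pass through the pipeline: slice -> cluster -> reward
action = ManipulatorAction(rotation=(0.0, 0.0), translation=(0.0, 140.0))
sliced_volume = volumes[0].get_volume_slice(
    slice_shape=(volumes[0].GetSize()[0], volumes[0].GetSize()[2]),
    action=action,
)
sliced_img = sitk.GetArrayFromImage(sliced_volume).T
clusters = TissueClusters.from_labelmap_slice(sliced_img)
print(f"reward = {anatomy_based_rwd(clusters):.2f}")
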
t = [160, 155, 150, 148, 146, 142, 140, 140, 115, 120, 125, 125, 130, 130, 135, 138, 140, 140, 140]  # translations (mm)
z = [0, -5, 0, 0, 5, 15, 19.3, -10, 0, 0, 0, 5, -8, 8, 0, -10, -10, 10, 19.3]  # rotations (deg)
slice_shape = (volumes[0].GetSize()[0], volumes[0].GetSize()[2])
size = np.array(volumes[0].GetSize()) * np.array(volumes[0].GetSpacing())  # physical size in mm

transversal_extent = (0, size[0], 0, size[2])
longitudinal_extent = (0, size[1], 0, size[2])
frontal_extent = (0, size[0], size[1], 0)  # vertical axis flipped so the origin is at the top


# Linear function used to draw the dashed slice line on the frontal view
def linear_function(x: np.ndarray, m: float, b: float) -> np.ndarray:
    return m * x + b


# Create a figure and a gridspec with two rows and two columns
fig = plt.figure(constrained_layout=True, figsize=(8, 6))
gs = fig.add_gridspec(2, 2)
camera = Camera(fig)

# Add subplots
ax1 = fig.add_subplot(gs[:, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 1])

for i in range(len(t)):
    # Subplot 1: frontal view with the dashed slice line
    ax1.imshow(img_array_1[40, :, :], extent=frontal_extent)
    x_dash = np.arange(size[0])
    b = t[i]
    y_dash = linear_function(x_dash, np.tan(np.deg2rad(z[i])), b)
    ax1.plot(x_dash, y_dash, linestyle="--", color="red")
    ax1.set_title("Slice cut")

    # ACTION
    sliced_volume = volumes[0].get_volume_slice(
        slice_shape=slice_shape,
        action=ManipulatorAction(rotation=(z[i], 0.0), translation=(0.0, t[i])),
    )
    sliced_img = sitk.GetArrayFromImage(sliced_volume).T  # transposed for clustering
    ax2.imshow(
        sliced_img.T,  # transpose back for display
        origin="lower",
        extent=transversal_extent,
    )
    ax2.set_title(f"Slice {i}")

    # OBSERVATION
    clusters = TissueClusters.from_labelmap_slice(sliced_img)
    ax3 = show_clusters(clusters, sliced_img, ax3, extent=transversal_extent)
    ax3.set_title(f"Clusters {i}")

    # REWARD
    loss = anatomy_based_rwd(clusters)
    ax3.text(0, 0, f"Loss: {loss:.2f}", fontsize=12, color="red")

    camera.snap()
    plt.close()
animation = camera.animate()
HTML(animation.to_jshtml())

Rotations are defined in degrees and translations in millimeters. For the agent to take meaningful actions, the action space needs bounds. The rotation bound is at most 180 degrees, since any greater angle can be reached by rotating in the opposite direction, while the translation bounds are set so that the slice stays within the image. The physical size of the volume, expressed in mm, is calculated as the difference between the physical coordinates of the first and the last voxel of the volume.

origin = volumes[0].GetOrigin()
spacing = volumes[0].GetSpacing()
size = volumes[0].GetSize()
end = volumes[0].TransformIndexToPhysicalPoint(size)
print(f"{origin=},\n {spacing=},\n {end=}")
dim = np.subtract(end, origin)
physical_size = size * np.array(spacing)
index_dim = dim / spacing
print(f"{dim=} == {physical_size},\n {index_dim=} == {size=}")
origin=(-74.90050506591797, -106.84154510498047, -30.0),
 spacing=(0.1944444477558136, 0.1944444477558136, 0.9999990463256836),
 end=(75.01616415381423, 107.04734742641455, 30.9999418258667)
dim=array([149.91666922, 213.88889253,  60.99994183]) == [149.91666922 213.88889253  60.99994183],
 index_dim=array([ 771., 1100.,   61.]) == size=(771, 1100, 61)
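
As a rough sketch of how such bounds could be derived from the volume (assuming translations are expressed in mm within the frontal plane; the exact convention is defined by the environment):

# Sketch: derive loose action bounds from the volume's physical extent.
# Assumption: translations are expressed in mm within the frontal plane.
physical_size = np.array(volumes[0].GetSize()) * np.array(volumes[0].GetSpacing())

# Any rotation beyond 180 degrees can be reached by rotating the other way,
# so 180 is the loosest sensible rotation bound; the environment below uses
# much tighter bounds of (30.0, 10.0) degrees.
max_rotation = 180.0

# Keep translations inside the physical image extent.
translation_bounds = (0.0, float(physical_size[:2].max()))
print(f"{max_rotation=}, {translation_bounds=}")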
volume_size = volumes[0].GetSize()

env = ArmscanEnv(
    name2volume={"1": volumes[0], "2": volumes[1]},
    observation=LabelmapSliceAsChannelsObservation(
        slice_shape=(volume_size[0], volume_size[2]),
        action_shape=(4,),  # two rotation + two translation components
    ),
    slice_shape=(volume_size[0], volume_size[2]),
    reward_metric=LabelmapClusteringBasedReward(),
    termination_criterion=LabelmapEnvTerminationCriterion(),
    max_episode_len=10,
    rotation_bounds=(30.0, 10.0),  # degrees
    translation_bounds=(0.0, None),  # mm; None leaves the upper bound unset
    render_mode="animation",
    apply_volume_transformation=True,
)
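To exercise the environment end to end, the loop below runs 50 steps of a simple epsilon-greedy demo: with probability 1 - epsilon the agent takes a random action sampled from the action space, and with probability epsilon it jumps directly to the optimal state via step_to_optimal_state. Every step is rendered into a single animation, and the environment is reset (without resetting the render) whenever an episode terminates or is truncated.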

observation, info = env.reset()
epsilon = 0.1
for _ in range(50):
    if np.random.rand() > epsilon:
        # Most of the time, take a random action from the action space
        action = env.action_space.sample()
        observation, reward, terminated, truncated, info = env.step(action)
    else:
        # Occasionally jump straight to the optimal state
        observation, reward, terminated, truncated, info = env.step_to_optimal_state()
    env.render()

    if terminated or truncated:
        observation, info = env.reset(reset_render=False)
animation = env.get_cur_animation()
env.close()
HTML(animation.to_jshtml())