The scanning sub-problem in fewer dimensions

Hide code cell content
import warnings

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.base import EnvRollout
from armscan_env.envs.labelmaps_navigation import (
    ArmscanEnv,
    LabelmapClusteringBasedReward,
    LabelmapEnvTerminationCriterion,
)
from armscan_env.envs.observations import (
    ActionRewardObservation,
)
from armscan_env.envs.state_action import ManipulatorAction
from armscan_env.volumes.loading import load_sitk_volumes
from armscan_env.wrapper import ArmscanEnvFactory
from tqdm import tqdm

from tianshou.highlevel.env import EnvMode

# Load the project configuration, then silence UserWarnings raised from the
# armscan_env.envs.state_action module during the walk-throughs below.
config = get_config()
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="armscan_env.envs.state_action",
)
# Use matplotlib's stock style for all figures in this notebook.
plt.style.use(style="default")

The scanning sub-problem in fewer dimensions

def walk_through_env(
    env: ArmscanEnv,
    n_steps: int = 10,
    reset: bool = True,
    show_pbar: bool = True,
    render_title: str = "Labelmap slice",
) -> EnvRollout:
    """Sweep the slice position along the y-axis in ``n_steps`` evenly spaced steps.

    If the env's action space has shape ``(1,)`` (projected env), the raw action
    values are swept from -1 to 1. Otherwise the env's optimal action is taken at
    every step, its y-translation is replaced by the sweep value (which runs over
    ``env.translation_bounds``), and the result is normalized before stepping.

    Args:
        env: environment to walk through.
        n_steps: number of evenly spaced y positions to visit.
        reset: if True, reset the env first and record the initial state.
        show_pbar: wrap the sweep in a tqdm progress bar.
        render_title: title passed to ``env.render`` after the reset and every step.

    Returns:
        The rollout containing the (optional) reset entry and one entry per step.
    """
    rollout = EnvRollout()

    if reset:
        obs, info = env.reset()
        env.render(title=render_title)
        # Record the post-reset state in the rollout.
        rollout.append_reset(
            obs,
            info,
            reward=env.compute_cur_reward(),
            terminated=env.should_terminate(),
            truncated=env.should_truncate(),
        )

    is_projected = env.action_space.shape == (1,)
    if is_projected:
        y_low, y_high = -1, 1
    else:
        y_low, y_high = env.translation_bounds[0], env.translation_bounds[1]

    print(f"Walking through y-axis from {y_low} to {y_high} in {n_steps} steps")

    sweep = np.linspace(y_low, y_high, n_steps)
    iterator = tqdm(sweep, desc="Step:") if show_pbar else sweep

    for y in iterator:
        if is_projected:
            # projected environment: feed the sweep value directly as the 1-element action
            action = np.array([y])
        else:
            manipulator_action = env.get_optimal_action()
            manipulator_action.translation = (manipulator_action.translation[0], y)
            action = manipulator_action.to_normalized_array(
                rotation_bounds=env.rotation_bounds,
                translation_bounds=env.translation_bounds,
            )
        obs, reward, terminated, truncated, info = env.step(action)
        rollout.append_step(action, obs, reward, terminated, truncated, info)
        env.render(title=render_title)
    return rollout


def plot_rollout_rewards(env_rollout: EnvRollout, show: bool = True) -> None:
    """Plot the per-step rewards of a rollout and shade terminated steps in red.

    Args:
        env_rollout: rollout whose ``rewards`` and ``terminated`` sequences are plotted.
        show: if True, call ``plt.show()`` at the end.
    """
    plt.plot(env_rollout.rewards, label="Reward")

    steps_where_terminated = np.where(env_rollout.terminated)[0]
    # Mark each terminated step with a red transparent band one step wide. Only the
    # first band carries a legend label, so "Terminated" is listed exactly once and
    # only when at least one step actually terminated.
    for i, step in enumerate(steps_where_terminated):
        plt.axvspan(
            step - 0.5,
            step + 0.5,
            color="red",
            alpha=0.5,
            label="Terminated" if i == 0 else "_nolegend_",
        )

    plt.xlabel("Step")
    plt.ylabel("Reward")

    # Build the legend from the labelled artists. Positional labels depended on the
    # number and order of artists already present on the current axes.
    plt.legend()

    if show:
        plt.show()
# Load the (normalized) labelmap volumes and build the full, unprojected
# environment on the first one. slice_shape uses the volume size along axes 0 and 2.
volumes = load_sitk_volumes(normalize=True)
volume_size = volumes[0].GetSize()

env = ArmscanEnvFactory(
    name2volume={"1": volumes[0]},
    observation=ActionRewardObservation(action_shape=(4,)).to_array_observation(),
    slice_shape=(volume_size[0], volume_size[2]),
    reward_metric=LabelmapClusteringBasedReward(),
    termination_criterion=LabelmapEnvTerminationCriterion(),
    max_episode_len=10,
    rotation_bounds=(30.0, 10.0),
    translation_bounds=(0.0, None),
    render_mode="animation",
    n_stack=2,
).create_env(EnvMode.WATCH)

# Physical extent of the volume per axis: voxel counts times voxel spacing.
volume_mm_size = np.multiply(volumes[0].GetSize(), volumes[0].GetSpacing())
print(f"Volume size in mm: {volume_mm_size}")
Volume size in mm: [149.91666922 213.88889253  60.99994183]
# Walk through the y-axis in steps of roughly 5 mm. The step count is derived from
# the volume's physical y extent, so it adapts to the loaded volume.
env_rollout = walk_through_env(env, n_steps=int(volume_mm_size[1] // 5))

plot_rollout_rewards(env_rollout)
Walking through y-axis from 0.0 to 213.88889253139496 in 42 steps
Step::   0%|          | 0/42 [00:00<?, ?it/s]
Step::   2%|▏         | 1/42 [00:00<00:07,  5.86it/s]
Step::   5%|▍         | 2/42 [00:00<00:06,  5.87it/s]
Step::   7%|▋         | 3/42 [00:00<00:06,  5.90it/s]
Step::  10%|▉         | 4/42 [00:00<00:06,  5.93it/s]
Step::  12%|█▏        | 5/42 [00:00<00:06,  5.96it/s]
Step::  14%|█▍        | 6/42 [00:01<00:06,  5.88it/s]
Step::  17%|█▋        | 7/42 [00:01<00:05,  5.84it/s]
Step::  19%|█▉        | 8/42 [00:01<00:05,  5.81it/s]
Step::  21%|██▏       | 9/42 [00:01<00:05,  5.83it/s]
Step::  24%|██▍       | 10/42 [00:01<00:05,  5.84it/s]
Step::  26%|██▌       | 11/42 [00:01<00:05,  5.90it/s]
Step::  29%|██▊       | 12/42 [00:02<00:05,  5.95it/s]
Step::  31%|███       | 13/42 [00:02<00:04,  5.94it/s]
Step::  33%|███▎      | 14/42 [00:02<00:04,  5.83it/s]
Step::  36%|███▌      | 15/42 [00:02<00:04,  5.85it/s]
Step::  38%|███▊      | 16/42 [00:02<00:04,  5.80it/s]
Step::  40%|████      | 17/42 [00:02<00:04,  5.81it/s]
Step::  43%|████▎     | 18/42 [00:03<00:04,  5.78it/s]
Step::  45%|████▌     | 19/42 [00:03<00:03,  5.79it/s]
Step::  48%|████▊     | 20/42 [00:03<00:03,  5.85it/s]
Step::  50%|█████     | 21/42 [00:03<00:03,  5.92it/s]
Step::  52%|█████▏    | 22/42 [00:03<00:03,  5.99it/s]
Step::  55%|█████▍    | 23/42 [00:03<00:03,  6.11it/s]
Step::  57%|█████▋    | 24/42 [00:04<00:02,  6.16it/s]
Step::  60%|█████▉    | 25/42 [00:04<00:02,  6.06it/s]
Step::  62%|██████▏   | 26/42 [00:04<00:02,  5.88it/s]
Step::  64%|██████▍   | 27/42 [00:04<00:02,  5.74it/s]
Step::  67%|██████▋   | 28/42 [00:04<00:02,  5.71it/s]
Step::  69%|██████▉   | 29/42 [00:04<00:02,  5.77it/s]
Step::  71%|███████▏  | 30/42 [00:05<00:02,  5.81it/s]
Step::  74%|███████▍  | 31/42 [00:05<00:01,  5.95it/s]
Step::  76%|███████▌  | 32/42 [00:05<00:01,  6.13it/s]
Step::  79%|███████▊  | 33/42 [00:05<00:01,  5.96it/s]
Step::  81%|████████  | 34/42 [00:05<00:01,  5.93it/s]
Step::  83%|████████▎ | 35/42 [00:05<00:01,  6.13it/s]
Step::  86%|████████▌ | 36/42 [00:06<00:00,  6.27it/s]
Step::  88%|████████▊ | 37/42 [00:06<00:00,  6.51it/s]
Step::  90%|█████████ | 38/42 [00:06<00:00,  6.71it/s]
Step::  93%|█████████▎| 39/42 [00:06<00:00,  7.05it/s]
Step::  95%|█████████▌| 40/42 [00:06<00:00,  7.49it/s]
Step::  98%|█████████▊| 41/42 [00:06<00:00,  7.92it/s]
Step:: 100%|██████████| 42/42 [00:06<00:00,  8.26it/s]
Step:: 100%|██████████| 42/42 [00:06<00:00,  6.17it/s]

../_images/a264ca11d446617f9c50619307480d5d773814fe87f8f86c23b56acfbdeadaa8.png

# Display, as HTML, the animation of the frames rendered during the walk-through of
# `env` (created with render_mode="animation"; env.render is called after every step).
env.get_cur_animation_as_html()

# Second environment on the same volume, with actions projected onto the y-axis
# (project_actions_to="y"), so its action space has a single component.
volume = volumes[0]
volume_size = volume.GetSize()

projected_env = ArmscanEnvFactory(
    name2volume={"1": volume},
    observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),
    slice_shape=(volume_size[0], volume_size[2]),
    reward_metric=LabelmapClusteringBasedReward(),
    termination_criterion=LabelmapEnvTerminationCriterion(),
    max_episode_len=10,
    rotation_bounds=(30.0, 10.0),
    translation_bounds=(0.0, None),
    render_mode="animation",
    n_stack=2,
    project_actions_to="y",
    apply_volume_transformation=True,
).create_env(EnvMode.WATCH)

# Physical extent per axis (voxel counts times spacing), read after env creation.
volume_mm_size = np.multiply(volume.GetSize(), volume.GetSpacing())
print(f"Volume size in mm: {volume_mm_size}, stepping through y-axis per 5 mm")

projected_env_rollout = walk_through_env(
    projected_env,
    n_steps=int(volume_mm_size[1] // 5),
    render_title="Projected labelmap slice",
)

# Reward plot (left) next to the frontal view of the volume with the range of
# optimal slice positions drawn as dashed lines (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Indices of the rollout steps at which the episode was flagged as terminated.
steps_where_terminated = np.where(projected_env_rollout.terminated)[0]

# Plot the rewards and shade terminated steps. Only the first band is labelled, so
# the legend lists "Terminated" once (and not at all when nothing terminated).
ax1.plot(projected_env_rollout.rewards, label="Reward")
for i, step in enumerate(steps_where_terminated):
    ax1.axvspan(
        step - 0.5,
        step + 0.5,
        color="red",
        alpha=0.5,
        label="Terminated" if i == 0 else "_nolegend_",
    )
ax1.set_xlabel("Step")
ax1.set_ylabel("Reward")
ax1.legend()

# Extract the frontal view at half depth of the used volume. The SimpleITK array
# view is indexed (z, y, x), so the first index selects along GetSize()[2].
frontal_extent = (0, volume_mm_size[0], volume_mm_size[1], 0)
half_depth = volume.GetSize()[2] // 2
frontal_view = sitk.GetArrayViewFromImage(volume)[half_depth, :, :]
ax2.imshow(frontal_view, extent=frontal_extent)

if len(steps_where_terminated) == 0:
    # Without any terminated step there is no first/last optimal action to draw.
    print("No terminated step in the rollout; skipping the optimal actions range.")
else:
    x_dash = np.arange(volume_mm_size[0])
    # The first and last terminated steps bound the range of optimal actions.
    bounding_steps = (steps_where_terminated[0], steps_where_terminated[-1])
    for i, step in enumerate(bounding_steps):
        full_action = projected_env.get_full_action_array_from_projected_action(
            projected_env_rollout.actions[step],
        )
        manipulator_action = ManipulatorAction.from_normalized_array(
            full_action,
            projected_env.rotation_bounds,
            projected_env.translation_bounds,
        )
        # Slice line in the frontal plane: y = x * tan(rotation[0]) + translation[1]
        # (rotation[0] converted from degrees), clipped to the volume's y extent.
        slope = np.tan(np.deg2rad(manipulator_action.rotation[0]))
        y_dash = x_dash * slope + manipulator_action.translation[1]
        y_dash = np.clip(y_dash, 0, volume_mm_size[1])
        ax2.plot(
            x_dash,
            y_dash,
            linestyle="--",
            color="red",
            label="optimal actions range" if i == 0 else "_nolegend_",
        )
    ax2.legend()

plt.tight_layout()
plt.show()
Volume size in mm: [149.91666922 213.88889253  60.99994183], stepping through y-axis per 5 mm
Walking through y-axis from -1 to 1 in 42 steps
Step::   0%|          | 0/42 [00:00<?, ?it/s]
Step::   2%|▏         | 1/42 [00:00<00:07,  5.61it/s]
Step::   5%|▍         | 2/42 [00:00<00:07,  5.64it/s]
Step::   7%|▋         | 3/42 [00:00<00:06,  5.66it/s]
Step::  10%|▉         | 4/42 [00:00<00:06,  5.69it/s]
Step::  12%|█▏        | 5/42 [00:00<00:06,  5.71it/s]
Step::  14%|█▍        | 6/42 [00:01<00:06,  5.69it/s]
Step::  17%|█▋        | 7/42 [00:01<00:06,  5.70it/s]
Step::  19%|█▉        | 8/42 [00:01<00:05,  5.74it/s]
Step::  21%|██▏       | 9/42 [00:01<00:05,  5.76it/s]
Step::  24%|██▍       | 10/42 [00:01<00:05,  5.83it/s]
Step::  26%|██▌       | 11/42 [00:01<00:05,  5.90it/s]
Step::  29%|██▊       | 12/42 [00:02<00:05,  5.91it/s]
Step::  31%|███       | 13/42 [00:02<00:04,  5.80it/s]
Step::  33%|███▎      | 14/42 [00:02<00:04,  5.81it/s]
Step::  36%|███▌      | 15/42 [00:02<00:04,  5.77it/s]
Step::  38%|███▊      | 16/42 [00:02<00:04,  5.77it/s]
Step::  40%|████      | 17/42 [00:02<00:04,  5.74it/s]
Step::  43%|████▎     | 18/42 [00:03<00:04,  5.76it/s]
Step::  45%|████▌     | 19/42 [00:03<00:03,  5.82it/s]
Step::  48%|████▊     | 20/42 [00:03<00:03,  5.84it/s]
Step::  50%|█████     | 21/42 [00:03<00:03,  5.97it/s]
Step::  52%|█████▏    | 22/42 [00:03<00:03,  6.08it/s]
Step::  55%|█████▍    | 23/42 [00:03<00:03,  6.10it/s]
Step::  57%|█████▋    | 24/42 [00:04<00:02,  6.03it/s]
Step::  60%|█████▉    | 25/42 [00:04<00:02,  5.87it/s]
Step::  62%|██████▏   | 26/42 [00:04<00:02,  5.79it/s]
Step::  64%|██████▍   | 27/42 [00:04<00:02,  5.71it/s]
Step::  67%|██████▋   | 28/42 [00:04<00:02,  5.76it/s]
Step::  69%|██████▉   | 29/42 [00:04<00:02,  5.81it/s]
Step::  71%|███████▏  | 30/42 [00:05<00:02,  5.91it/s]
Step::  74%|███████▍  | 31/42 [00:05<00:01,  6.06it/s]
Step::  76%|███████▌  | 32/42 [00:05<00:01,  6.11it/s]
Step::  79%|███████▊  | 33/42 [00:05<00:01,  5.99it/s]
Step::  81%|████████  | 34/42 [00:05<00:01,  6.13it/s]
Step::  83%|████████▎ | 35/42 [00:05<00:01,  6.24it/s]
Step::  86%|████████▌ | 36/42 [00:06<00:00,  6.51it/s]
Step::  88%|████████▊ | 37/42 [00:06<00:00,  6.70it/s]
Step::  90%|█████████ | 38/42 [00:06<00:00,  6.98it/s]
Step::  93%|█████████▎| 39/42 [00:06<00:00,  7.44it/s]
Step::  95%|█████████▌| 40/42 [00:06<00:00,  7.83it/s]
Step::  98%|█████████▊| 41/42 [00:06<00:00,  8.19it/s]
Step:: 100%|██████████| 42/42 [00:06<00:00,  8.46it/s]
Step:: 100%|██████████| 42/42 [00:06<00:00,  6.18it/s]

../_images/635d57b0e513cab8842a902ba506c5de9a4fd5d4c284a76883c1c9a81d2086ee.png
# Display, as HTML, the animation of the frames rendered while walking through
# `projected_env` (created with render_mode="animation").
# NOTE(review): if these two lines share one notebook cell, only the value of the
# last expression (the reset() result) is displayed; put the animation call in its
# own cell or wrap it in display(...) — confirm against the original notebook.
projected_env.get_cur_animation_as_html()
# Reset the projected env; the displayed value is the (observation, info) pair it returns.
projected_env.reset()
(array([ 0.   , -0.125,  0.   , -0.125], dtype=float32), {})