In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are reloaded
# before each cell executes and source edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.base import EnvRollout
from armscan_env.envs.labelmaps_navigation import (
    ArmscanEnv,
    LabelmapClusteringBasedReward,
    LabelmapEnvTerminationCriterion,
)
from armscan_env.envs.observations import (
    ActionRewardObservation,
)
from armscan_env.envs.state_action import ManipulatorAction
from armscan_env.volumes.loading import load_sitk_volumes
from armscan_env.wrapper import ArmscanEnvFactory
from tqdm import tqdm

from tianshou.highlevel.env import EnvMode

# Load the armscan_env configuration (not referenced again in the cells visible here).
config = get_config()
# Suppress UserWarnings issued from the armscan_env.envs.state_action module.
warnings.filterwarnings("ignore", category=UserWarning, module="armscan_env.envs.state_action")
# Use matplotlib's default style for every figure in this notebook.
plt.style.use("default")

# The scanning sub-problem in fewer dimensions

In [None]:
def walk_through_env(
    env: ArmscanEnv,
    n_steps: int = 10,
    reset: bool = True,
    show_pbar: bool = True,
    render_title: str = "Labelmap slice",
) -> EnvRollout:
    """Step an environment through evenly spaced y positions, rendering and recording each step.

    If ``reset`` is true, the environment is reset and rendered first, and the initial
    observation is stored in the rollout together with the current reward, termination
    and truncation flags.

    For an environment whose action space has shape ``(1,)`` the actions are the values
    of ``np.linspace(-1, 1, n_steps)``. Otherwise every action is
    ``env.get_optimal_action()`` with its second translation component replaced by the
    walked value, converted with ``to_normalized_array``.

    Stepping continues for all ``n_steps`` positions; the ``terminated`` and ``truncated``
    flags returned by ``env.step`` are recorded but never stop the walk.

    :param env: the environment to walk through.
    :param n_steps: number of evenly spaced y positions to visit.
    :param reset: whether to reset the environment (and record the initial state) first.
    :param show_pbar: whether to wrap the iteration in a tqdm progress bar.
    :param render_title: title passed to ``env.render`` for every rendered frame.
    :return: the rollout containing the (optional) reset entry and one entry per step.
    """
    env_rollout = EnvRollout()

    if reset:
        obs, info = env.reset()
        env.render(title=render_title)

        # add initial state to the rollout
        env_rollout.append_reset(
            obs,
            info,
            reward=env.compute_cur_reward(),
            terminated=env.should_terminate(),
            truncated=env.should_truncate(),
        )

    env_is_1d = env.action_space.shape == (1,)

    # NOTE(review): translation_bounds[0] / [1] are used here as the lower / upper end of
    # the y range. Elsewhere in this notebook the bounds are passed as per-axis values
    # (e.g. translation_bounds=(0.0, None)) — confirm this is the intended interpretation.
    y_lower_bound = -1 if env_is_1d else env.translation_bounds[0]
    y_upper_bound = 1 if env_is_1d else env.translation_bounds[1]

    print(f"Walking through y-axis from {y_lower_bound} to {y_upper_bound} in {n_steps} steps")

    y_actions = np.linspace(y_lower_bound, y_upper_bound, n_steps)
    if show_pbar:
        # tqdm appends ": " to the description itself, so no trailing colon here
        y_actions = tqdm(y_actions, desc="Step")

    for y_action in y_actions:
        if env_is_1d:
            # projected environment: the walked value already is the full action
            cur_y_action = np.array([y_action])
        else:
            # keep the optimal action's other components and override only the y translation
            cur_y_action = env.get_optimal_action()
            cur_y_action.translation = (cur_y_action.translation[0], y_action)
            cur_y_action = cur_y_action.to_normalized_array(
                rotation_bounds=env.rotation_bounds,
                translation_bounds=env.translation_bounds,
            )
        obs, reward, terminated, truncated, info = env.step(cur_y_action)

        env_rollout.append_step(cur_y_action, obs, reward, terminated, truncated, info)
        env.render(title=render_title)
    return env_rollout


def plot_rollout_rewards(env_rollout: EnvRollout, show: bool = True) -> None:
    """Plot the rewards of a rollout on the current axes and highlight terminated steps.

    Every step whose ``terminated`` flag is set is marked with a semi-transparent red
    band one step wide, centred on the step index.

    Legend entries are attached to the artists explicitly rather than passed positionally
    to ``plt.legend``, so they stay correct even if the current axes already contain
    other artists.

    :param env_rollout: rollout providing the ``rewards`` and ``terminated`` sequences.
    :param show: whether to call ``plt.show()`` after drawing.
    """
    plt.plot(env_rollout.rewards, label="Reward")

    steps_where_terminated = np.where(env_rollout.terminated)[0]
    for band_idx, step in enumerate(steps_where_terminated):
        plt.axvspan(
            step - 0.5,
            step + 0.5,
            color="red",
            alpha=0.5,
            # only the first band gets a legend entry; the rest are hidden from the legend
            label="Terminated" if band_idx == 0 else "_nolegend_",
        )

    plt.xlabel("Step")
    plt.ylabel("Reward")

    plt.legend()

    if show:
        plt.show()

In [None]:
# Load the volumes through the armscan_env loader, requesting normalization.
volumes = load_sitk_volumes(normalize=True)

In [None]:
# Build the environment on the first loaded volume, with observations configured
# for 4-component actions.
first_volume = volumes[0]
volume_size = first_volume.GetSize()

env_factory = ArmscanEnvFactory(
    name2volume={"1": first_volume},
    observation=ActionRewardObservation(action_shape=(4,)).to_array_observation(),
    # slices are spanned by the first and third volume axes
    slice_shape=(volume_size[0], volume_size[2]),
    reward_metric=LabelmapClusteringBasedReward(),
    termination_criterion=LabelmapEnvTerminationCriterion(),
    max_episode_len=10,
    rotation_bounds=(30.0, 10.0),
    translation_bounds=(0.0, None),
    render_mode="animation",
    n_stack=2,
)
env = env_factory.create_env(EnvMode.WATCH)

In [None]:
# Physical extent per axis: number of voxels multiplied by the voxel spacing.
voxel_counts = np.array(volumes[0].GetSize())
voxel_spacing = np.array(volumes[0].GetSpacing())
volume_mm_size = voxel_counts * voxel_spacing
print(f"Volume size in mm: {volume_mm_size}")

In [None]:
# Walk the y-axis in increments of roughly 5 mm. The number of steps is derived from the
# measured y extent of the volume instead of a hard-coded length, consistent with the
# projected-environment walk-through further below.
step_size_mm = 5
env_rollout = walk_through_env(env, int(volume_mm_size[1] // step_size_mm))

plot_rollout_rewards(env_rollout)

# Last expression of the cell, so the notebook displays the animation.
env.get_cur_animation_as_html()

In [None]:
volume = volumes[0]
volume_size = volume.GetSize()

# Environment whose actions are projected onto the y-axis (1-component actions).
projected_env = ArmscanEnvFactory(
    name2volume={"1": volume},
    observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),
    slice_shape=(volume_size[0], volume_size[2]),
    reward_metric=LabelmapClusteringBasedReward(),
    termination_criterion=LabelmapEnvTerminationCriterion(),
    max_episode_len=10,
    rotation_bounds=(30.0, 10.0),
    translation_bounds=(0.0, None),
    render_mode="animation",
    n_stack=2,
    project_actions_to="y",
    apply_volume_transformation=True,
).create_env(EnvMode.WATCH)

volume_mm_size = np.array(volume.GetSize()) * np.array(volume.GetSpacing())
print(f"Volume size in mm: {volume_mm_size}, stepping through y-axis per 5 mm")

projected_env_rollout = walk_through_env(
    projected_env,
    int(volume_mm_size[1] // 5),
    render_title="Projected labelmap slice",
)

# Generate the reward plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Plot the rewards and mark every terminated step with a red band
ax1.plot(projected_env_rollout.rewards)
steps_where_terminated = np.where(projected_env_rollout.terminated)[0]
for step in steps_where_terminated:
    ax1.axvspan(step - 0.5, step + 0.5, color="red", alpha=0.5)
ax1.set_xlabel("Step")
ax1.set_ylabel("Reward")
ax1.legend(["Reward", "Terminated"])

# Extract the frontal view at half depth of the used volume
frontal_extent = (0, volume_mm_size[0], volume_mm_size[1], 0)
half_depth = volume.GetSize()[2] // 2
frontal_view = sitk.GetArrayViewFromImage(volume)[half_depth, :, :]
ax2.imshow(frontal_view, extent=frontal_extent)

if steps_where_terminated.size == 0:
    # Without any terminated step there is no optimal range to draw; indexing the
    # empty index array would raise an IndexError.
    print("No terminated steps in the rollout, skipping the optimal actions range overlay")
else:
    x_dash = np.arange(volume_mm_size[0])

    # Draw one dashed line for the first and one for the last terminated step; together
    # they delimit the range of optimal actions along the y-axis.
    for terminated_step in (steps_where_terminated[0], steps_where_terminated[-1]):
        projected_action = projected_env_rollout.actions[terminated_step]
        full_action = projected_env.get_full_action_array_from_projected_action(
            projected_action,
        )
        manipulator_action = ManipulatorAction.from_normalized_array(
            full_action,
            projected_env.rotation_bounds,
            projected_env.translation_bounds,
        )

        # Line through the y translation with a slope given by the first rotation angle,
        # clipped so it stays inside the displayed volume extent.
        y_intercept = manipulator_action.translation[1]
        slope = np.tan(np.deg2rad(manipulator_action.rotation[0]))
        y_dash = np.clip(x_dash * slope + y_intercept, 0, volume_mm_size[1])
        ax2.plot(x_dash, y_dash, linestyle="--", color="red")

    ax2.legend(["optimal actions range"])

plt.tight_layout()
plt.show()

In [None]:
# Capture the animation before resetting the environment and emit it as the final
# expression: a notebook cell only displays the value of its last expression, so the
# animation would otherwise be computed and silently discarded.
projected_animation_html = projected_env.get_cur_animation_as_html()
projected_env.reset()
projected_animation_html