import warnings
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.base import EnvRollout
from armscan_env.envs.labelmaps_navigation import (
ArmscanEnv,
LabelmapClusteringBasedReward,
LabelmapEnvTerminationCriterion,
)
from armscan_env.envs.observations import (
ActionRewardObservation,
)
from armscan_env.envs.state_action import ManipulatorAction
from armscan_env.volumes.loading import load_sitk_volumes
from armscan_env.wrapper import ArmscanEnvFactory
from tqdm import tqdm
from tianshou.highlevel.env import EnvMode
config = get_config()
warnings.filterwarnings("ignore", category=UserWarning, module="armscan_env.envs.state_action")
plt.style.use("default")
The scanning sub-problem in fewer dimensions

The environment can be simplified by projecting its actions onto fewer dimensions. Below, a sweep along the y-axis of the full labelmap environment is compared with the same sweep in a projected environment whose only action is the y translation.
def walk_through_env(
env: ArmscanEnv,
n_steps: int = 10,
reset: bool = True,
show_pbar: bool = True,
render_title: str = "Labelmap slice",
) -> EnvRollout:
env_rollout = EnvRollout()
if reset:
obs, info = env.reset()
env.render(title=render_title)
# add initial state to the rollout
reward = env.compute_cur_reward()
terminated = env.should_terminate()
truncated = env.should_truncate()
env_rollout.append_reset(
obs,
info,
reward=reward,
terminated=terminated,
truncated=truncated,
)
env_is_1d = env.action_space.shape == (1,)
y_lower_bound = -1 if env_is_1d else env.translation_bounds[0]
y_upper_bound = 1 if env_is_1d else env.translation_bounds[1]
print(f"Walking through y-axis from {y_lower_bound} to {y_upper_bound} in {n_steps} steps")
y_actions = np.linspace(y_lower_bound, y_upper_bound, n_steps)
if show_pbar:
y_actions = tqdm(y_actions, desc="Step:")
for y_action in y_actions:
if not env_is_1d:
cur_y_action = env.get_optimal_action()
cur_y_action.translation = (cur_y_action.translation[0], y_action)
cur_y_action = cur_y_action.to_normalized_array(
rotation_bounds=env.rotation_bounds,
translation_bounds=env.translation_bounds,
)
else:
# projected environment
cur_y_action = np.array([y_action])
obs, reward, terminated, truncated, info = env.step(cur_y_action)
env_rollout.append_step(cur_y_action, obs, reward, terminated, truncated, info)
env.render(title=render_title)
return env_rollout
def plot_rollout_rewards(env_rollout: EnvRollout, show: bool = True) -> None:
plt.plot(env_rollout.rewards)
steps_where_terminated = np.where(env_rollout.terminated)[0]
# mark the steps where the environment was terminated with a red transparent rectangle
# and add a legend that red means terminated
for step in steps_where_terminated:
plt.axvspan(step - 0.5, step + 0.5, color="red", alpha=0.5)
plt.xlabel("Step")
plt.ylabel("Reward")
plt.legend(["Reward", "Terminated"])
if show:
plt.show()
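The walk_through_env helper sweeps the y translation linearly: in the full environment it steps directly in millimetres between the translation bounds, while in the projected 1-D environment it steps through the normalized interval [-1, 1]. A minimal sketch of the rescaling that links the two, assuming the projected action maps [-1, 1] linearly onto the translation bounds (the bound values below are illustrative, not taken from the environment):

import numpy as np

y_min_mm, y_max_mm = 0.0, 213.9  # illustrative; the real values come from env.translation_bounds
normalized_actions = np.linspace(-1.0, 1.0, 5)  # what the projected environment receives
millimetre_positions = (normalized_actions + 1.0) / 2.0 * (y_max_mm - y_min_mm) + y_min_mm
print(millimetre_positions)  # [  0.     53.475 106.95  160.425 213.9  ]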
volumes = load_sitk_volumes(normalize=True)
volume_size = volumes[0].GetSize()
env = ArmscanEnvFactory(
name2volume={"1": volumes[0]},
observation=ActionRewardObservation(action_shape=(4,)).to_array_observation(),
slice_shape=(volume_size[0], volume_size[2]),
reward_metric=LabelmapClusteringBasedReward(),
termination_criterion=LabelmapEnvTerminationCriterion(),
max_episode_len=10,
rotation_bounds=(30.0, 10.0),
translation_bounds=(0.0, None),
render_mode="animation",
n_stack=2,
).create_env(EnvMode.WATCH)
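In the factory call above, translation_bounds=(0.0, None) leaves the upper translation bound open; as the sweep printout further down shows, it is resolved to the physical y extent of the volume (about 213.9 mm) once the environment is created. The rotation_bounds of (30.0, 10.0) presumably bound the two rotation angles of the action symmetrically. The resolved bounds can be inspected directly (these are the same attributes walk_through_env reads):

print(env.rotation_bounds, env.translation_bounds)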
volume_mm_size = np.array(volumes[0].GetSize()) * np.array(volumes[0].GetSpacing())
print(f"Volume size in mm: {volume_mm_size}")
Volume size in mm: [149.91666922 213.88889253 60.99994183]
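The y extent is roughly 214 mm, so passing 214 // 5 = 42 steps to the sweep below spaces consecutive slices about 5 mm apart (a quick, purely illustrative check):

n_steps = 214 // 5
print(213.89 / (n_steps - 1))  # ≈ 5.2 mm between consecutive slices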
env_rollout = walk_through_env(env, 214 // 5)
plot_rollout_rewards(env_rollout)
Walking through y-axis from 0.0 to 213.88889253139496 in 42 steps
Step:: 100%|██████████| 42/42 [00:06<00:00, 6.17it/s]
env.get_cur_animation_as_html()
volume = volumes[0]
volume_size = volume.GetSize()
projected_env = ArmscanEnvFactory(
name2volume={"1": volume},
observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),
slice_shape=(volume_size[0], volume_size[2]),
reward_metric=LabelmapClusteringBasedReward(),
termination_criterion=LabelmapEnvTerminationCriterion(),
max_episode_len=10,
rotation_bounds=(30.0, 10.0),
translation_bounds=(0.0, None),
render_mode="animation",
n_stack=2,
project_actions_to="y",
apply_volume_transformation=True,
).create_env(EnvMode.WATCH)
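With project_actions_to="y" the agent controls only the y translation; the environment fills in the remaining rotation and translation components, and get_full_action_array_from_projected_action (used further down) recovers the corresponding full normalized action. A hedged sketch of that expansion, shown for illustration only and assuming the environment has already been reset:

projected_action = np.array([0.0])  # centre of the normalized y range
full_action = projected_env.get_full_action_array_from_projected_action(projected_action)
# full_action is assumed to be a length-4 normalized array: two rotations and two translations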
volume_mm_size = np.array(volume.GetSize()) * np.array(volume.GetSpacing())
print(f"Volume size in mm: {volume_mm_size}, stepping through y-axis per 5 mm")
projected_env_rollout = walk_through_env(
projected_env,
int(volume_mm_size[1] // 5),
render_title="Projected labelmap slice",
)
# Generate the reward plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
# Plot the rewards
ax1.plot(projected_env_rollout.rewards)
steps_where_terminated = np.where(projected_env_rollout.terminated)[0]
for step in steps_where_terminated:
ax1.axvspan(step - 0.5, step + 0.5, color="red", alpha=0.5)
ax1.set_xlabel("Step")
ax1.set_ylabel("Reward")
ax1.legend(["Reward", "Terminated"])
# Extract the frontal view at half depth of the used volume
frontal_extent = (0, volume_mm_size[0], volume_mm_size[1], 0)
half_depth = volume.GetSize()[2] // 2
frontal_view = sitk.GetArrayViewFromImage(volume)[half_depth, :, :]
ax2.imshow(frontal_view, extent=frontal_extent)
# Determine the first and last optimal actions
first_optimal_action = projected_env_rollout.actions[
np.where(projected_env_rollout.terminated)[0][0]
]
full_first_optimal_action = projected_env.get_full_action_array_from_projected_action(
first_optimal_action,
)
last_optimal_action = projected_env_rollout.actions[
np.where(projected_env_rollout.terminated)[0][-1]
]
full_last_optimal_action = projected_env.get_full_action_array_from_projected_action(
last_optimal_action,
)
manipulator_action_1 = ManipulatorAction.from_normalized_array(
full_first_optimal_action,
projected_env.rotation_bounds,
projected_env.translation_bounds,
)
manipulator_action_2 = ManipulatorAction.from_normalized_array(
full_last_optimal_action,
projected_env.rotation_bounds,
projected_env.translation_bounds,
)
x_dash = np.arange(volume_mm_size[0])
# Calculate and clip y-values for the first line
b_1 = manipulator_action_1.translation[1]
y_dash_1 = x_dash * np.tan(np.deg2rad(manipulator_action_1.rotation[0])) + b_1
y_dash_1 = np.clip(y_dash_1, 0, volume_mm_size[1])
ax2.plot(x_dash, y_dash_1, linestyle="--", color="red")
# Calculate and clip y-values for the second line
b_2 = manipulator_action_2.translation[1]
y_dash_2 = x_dash * np.tan(np.deg2rad(manipulator_action_2.rotation[0])) + b_2
y_dash_2 = np.clip(y_dash_2, 0, volume_mm_size[1])
ax2.plot(x_dash, y_dash_2, linestyle="--", color="red")
ax2.legend(["optimal actions range"])
plt.tight_layout()
plt.show()
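Each dashed line traces a terminating slice across the frontal view: assuming rotation[0] is the in-plane rotation and translation[1] the y offset, a slice crosses the view along y = x * tan(theta) + t_y, so the band between the first and last terminating actions brackets the positions the termination criterion accepts. A minimal illustration of that line equation with made-up numbers:

theta_deg, t_y_mm = 10.0, 120.0                        # hypothetical rotation and y offset
x_mm = np.arange(0.0, 150.0)
y_mm = x_mm * np.tan(np.deg2rad(theta_deg)) + t_y_mm   # slice trace in the frontal view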
Volume size in mm: [149.91666922 213.88889253 60.99994183], stepping through y-axis per 5 mm
Walking through y-axis from -1 to 1 in 42 steps
Step:: 100%|██████████| 42/42 [00:06<00:00, 6.18it/s]
projected_env.get_cur_animation_as_html()
projected_env.reset()
(array([ 0. , -0.125, 0. , -0.125], dtype=float32), {})
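The reset observation has four entries even though the projected action is one-dimensional: ActionRewardObservation(action_shape=(1,)) contributes one action value plus one reward value per frame, and n_stack=2 stacks two such frames. This reading of the layout is an inference from the configuration above; it can be sanity-checked as follows:

obs, info = projected_env.reset()
assert obs.shape == (4,)  # assumed: (action, reward) pair repeated for each of the 2 stacked frames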