from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Generic, Self, TypeVar, cast
import numpy as np
import gymnasium as gym
TObs = TypeVar("TObs")
TAction = TypeVar("TAction")
class EnvPreconditionError(RuntimeError):
    """Raised when an operation is performed on an env whose preconditions are not met."""
@dataclass(kw_only=True)
class StateAction:
normalized_action_arr: Any
    # The state of the env is reflected by fields that subclasses add;
    # normalized_action_arr is a reserved field name. Subclasses should
    # override its type to be more specific.
TStateAction = TypeVar("TStateAction", bound=StateAction)
TEnv = TypeVar("TEnv", bound="ModularEnv")
class RewardMetric(Generic[TStateAction], ABC):
@abstractmethod
def compute_reward(self, state: TStateAction) -> float:
pass
@property
@abstractmethod
def range(self) -> tuple[float, float]:
pass
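
# Illustrative sketch (not part of the original module): a minimal RewardMetric
# that always returns a fixed value. Handy as a placeholder when wiring up an
# env for testing; the name ConstantRewardMetric is hypothetical.
class ConstantRewardMetric(RewardMetric[Any]):
    def __init__(self, value: float = 0.0):
        self.value = value

    def compute_reward(self, state: Any) -> float:
        return self.value

    @property
    def range(self) -> tuple[float, float]:
        return (self.value, self.value)
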
class TerminationCriterion(Generic[TEnv], ABC):
@abstractmethod
def should_terminate(self, env: TEnv) -> bool:
pass
class NeverTerminate(TerminationCriterion[Any]):
def should_terminate(self, env: Any) -> bool:
return False
class Observation(Generic[TStateAction, TObs], ABC):
@abstractmethod
def compute_observation(self, state: TStateAction) -> TObs:
pass
@property
@abstractmethod
def observation_space(self) -> gym.spaces.Space[TObs]:
pass
class ArrayObservation(Observation[TStateAction, np.ndarray], Generic[TStateAction], ABC):
    @property
    @abstractmethod
    def observation_space(self) -> gym.spaces.Box:
        pass
class DummyArrayObservation(ArrayObservation[Any]):
    def compute_observation(self, state: Any) -> np.ndarray:
        return np.array([0.5], dtype=np.float32)

    @property
    def observation_space(self) -> gym.spaces.Box:
        return gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
class DictObservation(Observation[TStateAction, dict[str, np.ndarray]], Generic[TStateAction], ABC):
    def to_array_observation(self) -> ArrayObservation[TStateAction]:
        return ArrayFromDictObservation(self)

    @property
    @abstractmethod
    def observation_space(self) -> gym.spaces.Dict:
        pass

    def merged_with(self, other: Self) -> "MergedDictObservation[TStateAction]":
        return MergedDictObservation([self, other])
class DummyDictObservation(DictObservation[Any]):
    def compute_observation(self, state: Any) -> dict[str, np.ndarray]:
        return {"dummy": np.array([0.5], dtype=np.float32)}

    @property
    def observation_space(self) -> gym.spaces.Dict:
        return gym.spaces.Dict(
            {"dummy": gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)},
        )
class ConcatenatedArrayObservation(ArrayObservation[TStateAction], Generic[TStateAction]):
    def __init__(self, array_observations: list[ArrayObservation[TStateAction]]):
        self.array_observations = array_observations

    def compute_observation(self, state: TStateAction) -> np.ndarray:
        # Concatenate along axis 0 so the result matches the Box produced by
        # concatenate_boxes below, which also stacks lows/highs along axis 0.
        return np.concatenate(
            [obs.compute_observation(state) for obs in self.array_observations],
            axis=0,
        )
@cached_property
def observation_space(self) -> gym.spaces.Box:
return self.concatenate_boxes([obs.observation_space for obs in self.array_observations])
@staticmethod
def concatenate_boxes(boxes: list[gym.spaces.Box]) -> gym.spaces.Box:
return gym.spaces.Box(
low=np.concatenate([box.low for box in boxes], axis=0),
high=np.concatenate([box.high for box in boxes], axis=0),
)
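
# Usage sketch (assumes each component emits a 1-D array, as DummyArrayObservation
# does): concatenating two shape-(1,) observations yields a single Box of shape (2,).
def _example_concatenated_observation() -> gym.spaces.Box:
    combined = ConcatenatedArrayObservation(
        [DummyArrayObservation(), DummyArrayObservation()],
    )
    return combined.observation_space  # Box with shape (2,)
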
class MergedDictObservation(DictObservation[TStateAction], Generic[TStateAction]):
def __init__(self, dict_observations: list[DictObservation[TStateAction]]):
self._dict_observations = dict_observations
self._merged_obs_dict: dict[str, gym.spaces.Box] = {}
for obs in dict_observations:
if duplicate_keys := self._merged_obs_dict.keys() & obs.observation_space.keys():
raise ValueError(f"Duplicate keys found in observation spaces: {duplicate_keys}")
self._merged_obs_dict.update(obs.observation_space.spaces)
@property
def dict_observations(self) -> list[DictObservation[TStateAction]]:
return self._dict_observations
def compute_observation(self, state: TStateAction) -> dict[str, np.ndarray]:
result = {}
for obs in self.dict_observations:
result.update(obs.compute_observation(state))
return result
@cached_property
def observation_space(self) -> gym.spaces.Dict:
return gym.spaces.Dict(spaces=self._merged_obs_dict)
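
# Note (illustrative): merging requires disjoint keys. For example,
# DummyDictObservation().merged_with(DummyDictObservation()) raises ValueError,
# since both instances contribute the "dummy" key.
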
class ArrayFromDictObservation(ArrayObservation[TStateAction], Generic[TStateAction]):
def __init__(self, dict_observation: DictObservation[TStateAction]):
self.dict_observation = dict_observation
def compute_observation(self, state: TStateAction) -> np.ndarray:
result_dict = self.dict_observation.compute_observation(state)
return np.concatenate(list(result_dict.values()), axis=0)
@cached_property
def observation_space(self) -> gym.spaces.Box:
return cast(
gym.spaces.Box,
gym.spaces.flatten_space(self.dict_observation.observation_space),
)
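
# Usage sketch: flattening a dict observation into a single 1-D array. The state
# argument is arbitrary here because DummyDictObservation ignores it.
def _example_flattened_dict_observation() -> np.ndarray:
    array_obs = DummyDictObservation().to_array_observation()
    return array_obs.compute_observation(None)  # array([0.5], dtype=float32)
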
@dataclass(kw_only=True)
class EnvStatus(Generic[TStateAction, TObs]):
episode_len: int
state_action: TStateAction | None
observation: TObs | None
reward: float | None
is_terminated: bool
is_truncated: bool
is_closed: bool
info: dict[str, Any]
class ModularEnv(gym.core.Env[TObs, TAction], Generic[TStateAction, TAction, TObs], ABC):
def __init__(
self,
reward_metric: RewardMetric[TStateAction],
observation: Observation[TStateAction, TObs],
termination_criterion: TerminationCriterion | None = None,
max_episode_len: int | None = None,
):
self.reward_metric = reward_metric
self.observation = observation
self.termination_criterion = termination_criterion or NeverTerminate()
self.max_episode_len = max_episode_len
self._is_closed = True
self._is_terminated = False
self._is_truncated = False
self._cur_episode_len = 0
self._cur_observation: TObs | None = None
self._cur_reward: float | None = None
self._cur_state_action: TStateAction | None = None
def get_cur_env_status(self) -> EnvStatus[TStateAction, TObs]:
return EnvStatus(
episode_len=self.cur_episode_len,
state_action=self.cur_state_action,
observation=self.cur_observation,
reward=self.cur_reward,
is_terminated=self.is_terminated,
            is_truncated=self.is_truncated,
is_closed=self.is_closed,
info=self.get_info_dict(),
)
@property
def is_closed(self) -> bool:
return self._is_closed
@property
def is_terminated(self) -> bool:
return self._is_terminated
@property
def is_truncated(self) -> bool:
return self._is_truncated
@property
def cur_state_action(self) -> TStateAction | None:
return self._cur_state_action
@property
def cur_observation(self) -> TObs | None:
return self._cur_observation
@property
def cur_reward(self) -> float | None:
return self._cur_reward
@property
def cur_episode_len(self) -> int:
return self._cur_episode_len
@property
@abstractmethod
def action_space(self) -> gym.spaces.Space[TAction]:
pass
@property
def observation_space(self) -> gym.spaces.Space[TObs]:
return self.observation.observation_space
def close(self) -> None:
self._cur_state_action = None
self._is_closed = True
self._cur_episode_len = 0
def _assert_cur_state(self) -> None:
if self.cur_state_action is None:
raise EnvPreconditionError(
"This operation requires a current state, but none is set. Did you call reset()?",
)
@abstractmethod
def compute_next_state(self, action: TAction) -> TStateAction:
pass
@abstractmethod
def sample_initial_state(self) -> TStateAction:
pass
def get_info_dict(self) -> dict[str, Any]:
# override this if you want to return additional info
return {}
def should_terminate(self) -> bool:
return self.termination_criterion.should_terminate(self)
def should_truncate(self) -> bool:
if self.max_episode_len is not None:
return self.cur_episode_len >= self.max_episode_len
return False
def compute_cur_observation(self) -> TObs:
self._assert_cur_state()
assert self.cur_state_action is not None
return self.observation.compute_observation(self.cur_state_action)
def compute_cur_reward(self) -> float:
self._assert_cur_state()
assert self.cur_state_action is not None
return self.reward_metric.compute_reward(self.cur_state_action)
def _update_cur_reward(self) -> None:
self._cur_reward = self.compute_cur_reward()
def _update_cur_observation(self) -> None:
self._cur_observation = self.compute_cur_observation()
def _update_is_terminated(self) -> None:
self._is_terminated = self.should_terminate()
def _update_is_truncated(self) -> None:
self._is_truncated = self.should_truncate()
def _update_observation_reward_termination(self) -> None:
# NOTE: the order of these calls is important!
self._update_cur_observation()
self._update_cur_reward()
self._update_is_terminated()
self._update_is_truncated()
def _go_to_next_state(self, action: TAction) -> None:
self._cur_state_action = self.compute_next_state(action)
self._update_observation_reward_termination()
def reset(self, seed: int | None = None, **kwargs: Any) -> tuple[TObs, dict[str, Any]]:
super().reset(seed=seed, **kwargs)
self._cur_state_action = self.sample_initial_state()
self._is_closed = False
self._cur_episode_len = 0
self._update_observation_reward_termination()
assert self.cur_observation is not None
return self.cur_observation, self.get_info_dict()
def step(self, action: TAction) -> tuple[TObs, float, bool, bool, dict[str, Any]]:
"""Step through the environment to navigate to the next state."""
self._go_to_next_state(action)
self._cur_episode_len += 1
assert self.cur_observation is not None
assert self.cur_reward is not None
return (
self.cur_observation,
self.cur_reward,
self.is_terminated,
self.is_truncated,
self.get_info_dict(),
)
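
# Minimal sketch of a concrete environment (all names below are illustrative and
# not part of the original module): a one-dimensional state driven directly by
# the action, combined with the dummy components and the ConstantRewardMetric
# sketch defined above.
@dataclass(kw_only=True)
class _PointStateAction(StateAction):
    value: float = 0.0


class _PointEnv(ModularEnv[_PointStateAction, np.ndarray, np.ndarray]):
    def __init__(self, max_episode_len: int = 10):
        super().__init__(
            reward_metric=ConstantRewardMetric(),
            observation=DummyArrayObservation(),
            max_episode_len=max_episode_len,
        )

    @property
    def action_space(self) -> gym.spaces.Box:
        return gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

    def sample_initial_state(self) -> _PointStateAction:
        return _PointStateAction(normalized_action_arr=np.zeros(1, dtype=np.float32))

    def compute_next_state(self, action: np.ndarray) -> _PointStateAction:
        return _PointStateAction(normalized_action_arr=action, value=float(action[0]))
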
@dataclass(kw_only=True)
class EnvRollout(Generic[TObs, TAction]):
observations: list[TObs] = field(default_factory=list)
rewards: list[float] = field(default_factory=list)
actions: list[TAction | None] = field(default_factory=list)
infos: list[dict[str, Any]] = field(default_factory=list)
terminated: list[bool] = field(default_factory=list)
truncated: list[bool] = field(default_factory=list)
def append_step(
self,
action: TAction,
observation: TObs,
reward: float,
terminated: bool,
truncated: bool,
info: dict[str, Any],
) -> None:
self.observations.append(observation)
self.rewards.append(reward)
self.actions.append(action)
self.terminated.append(terminated)
self.truncated.append(truncated)
self.infos.append(info)
def append_reset(
self,
observation: TObs,
info: dict[str, Any],
reward: float = 0,
terminated: bool = False,
truncated: bool = False,
) -> None:
self.observations.append(observation)
self.rewards.append(reward)
self.actions.append(None)
self.infos.append(info)
self.terminated.append(terminated)
self.truncated.append(truncated)
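
# Usage sketch: collecting a short rollout with EnvRollout; any ModularEnv
# (e.g. the _PointEnv sketch above) works here, and the helper name is
# illustrative.
def _example_rollout(
    env: ModularEnv[Any, np.ndarray, np.ndarray],
    num_steps: int = 5,
) -> EnvRollout[np.ndarray, np.ndarray]:
    rollout: EnvRollout[np.ndarray, np.ndarray] = EnvRollout()
    obs, info = env.reset()
    rollout.append_reset(obs, info)
    for _ in range(num_steps):
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        rollout.append_step(action, obs, reward, terminated, truncated, info)
        if terminated or truncated:
            break
    return rollout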