Coverage for src/accsr/config.py : 68%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
Contains helpers for defining and providing configuration classes. A typical usage would be to create the files
*config.py*, *config.json* and *config_local.json* in a project's root directory; entries from *config_local.json*
take precedence over those from *config.json*. An example of a config.py for a data-driven project is below.
For a non-data-driven project, the configuration class should inherit from
``accsr.config.ConfigurationBase``, and the resulting class will not have any pre-populated public entries.

>>> from accsr.config import DefaultDataConfiguration, ConfigProviderBase
>>>
>>> class __Configuration(DefaultDataConfiguration):
...     @property
...     def custom_entry_from_config(self):
...         return self._get_non_empty_entry("custom_entry_from_config")
...
...     @property
...     def existing_path_in_base_dir(self):
...         return self._get_existing_path(["base_dir", "path_in_base_dir"])
...
...     @property
...     def custom_path_in_processed_data(self):
...         return self.datafile_path("my_data", stage=self.PROCESSED, check_existence=False)
>>>
>>> class ConfigProvider(ConfigProviderBase[__Configuration]):
...     pass
>>>
>>> _config_provider = ConfigProvider()
>>>
>>>
>>> def get_config(reload=False):
...     return _config_provider.get_config(reload=reload)

"""
34import inspect
35import json
36import logging.handlers
37import os
38from abc import ABC
39from copy import deepcopy
40from pathlib import Path
41from typing import Callable, Dict, Generic, List, TextIO, Type, TypeVar, Union, get_args
# Module-level logger, named after this module (used e.g. to report which config files are read).
log = logging.getLogger(__name__)
def recursive_dict_update(d: Dict, u: Dict) -> Dict:
    """
    Modifies d inplace by overwriting with non-dict values from u and updating all dict-values recursively.
    Returns the modified d.

    If a value in u is a dict but the corresponding value in d is missing or not a dict (e.g. ``None``
    or a string), the value in d is replaced by a fresh dict merged from u's value.

    :param d: the dictionary to update in place
    :param u: the dictionary whose entries take precedence
    :return: d, after the update
    """
    # Based on https://stackoverflow.com/a/3233356/1069467
    for k, v in u.items():
        if isinstance(v, dict):
            existing = d.get(k)
            # a non-dict (e.g. null in an earlier config file) cannot be merged into - start fresh instead
            if not isinstance(existing, dict):
                existing = {}
            d[k] = recursive_dict_update(existing, v)
        else:
            d[k] = v
    return d
60def _replace_env_vars(conf: Union[dict], env_var_marker="env:"):
61 for k, v in conf.items():
62 if isinstance(v, str) and v.startswith(env_var_marker):
63 env_var_name = v.lstrip(env_var_marker)
64 conf[k] = os.getenv(env_var_name)
65 elif isinstance(v, dict):
66 _replace_env_vars(v, env_var_marker=env_var_marker)
69def _get_entry_with_replaced_env_vars(
70 entry: Union[str, float, list, dict], env_var_marker="env:"
71):
72 entry = deepcopy(entry)
73 if isinstance(entry, str) and entry.startswith(env_var_marker):
74 env_var_name = entry.lstrip(env_var_marker)
75 return os.getenv(env_var_name)
76 if isinstance(entry, dict):
77 _replace_env_vars(entry, env_var_marker=env_var_marker)
78 return entry
79 if isinstance(entry, list):
80 return [_get_entry_with_replaced_env_vars(v) for v in entry]
81 return entry
def get_config_reader(filename: str) -> Callable[[TextIO], Dict]:
    """
    Picks a parsing function for a configuration file based on its extension.

    :param filename: name or path of the configuration file
    :return: ``yaml.safe_load`` for .yaml/.yml files, ``json.load`` for .json files
    :raises ValueError: for any other extension
    """
    is_yaml = filename.endswith((".yaml", ".yml"))
    if is_yaml:
        # imported lazily so that yaml is only required when yaml configs are actually used
        import yaml

        return yaml.safe_load
    if filename.endswith(".json"):
        return json.load
    raise ValueError(
        f"Unsupported file type for {filename}. Supported are .yaml, .yml and .json."
    )
class ConfigurationBase(ABC):
    """
    Base class for reading and retrieving configuration entries. Do not instantiate this class directly but
    instead inherit from it.
    """

    # Prefix marking a string entry as a reference to an environment variable (named by the rest of the
    # string); applied by _get_non_empty_entry. Subclasses may override it.
    ENV_VAR_MARKER = "env:"

    def __init__(
        self,
        config_directory: Union[str, None] = None,
        config_files=("config.json", "config_local.json"),
    ):
        """
        :param config_directory: directory where to look for the config files. Typically, this will be a project's
            root directory. If None, the directory with the module containing the configuration class definition
            (inherited from ConfigurationBase) will be used.
        :param config_files: list of JSON or YAML configuration files (relative to config_directory) from which to read.
            The filenames should end in .json or .yaml/.yml.
            The configurations will be merged (dicts are merged, everything else is overwritten),
            entries from later files take precedence.
            Non-existing files from the list will be ignored without errors or warnings. However, at least
            one file must exist and contain entries for configuration to be read.
        :raises ValueError: if one of the file names has an unsupported extension (even if that file does not exist)
        :raises FileNotFoundError: if no configuration entries could be read from any of the files
        """
        self.config_directory = (
            config_directory
            if config_directory is not None
            else self._module_dir_path()
        )
        self.config = {}
        for filename in config_files:
            file_path = os.path.join(self.config_directory, filename)
            # resolved before the existence check, so an unsupported extension is reported in any case
            file_reader = get_config_reader(filename)
            if not os.path.exists(file_path):
                continue
            log.info(f"Reading configuration from {file_path}")
            # JSON is defined as UTF-8; don't depend on the platform's default encoding
            with open(file_path, "r", encoding="utf-8") as f:
                read_config = file_reader(f)
            # an empty YAML file is parsed as None - it simply contributes no entries
            if read_config:
                recursive_dict_update(self.config, read_config)
        if not self.config:
            raise FileNotFoundError(
                "No configuration entries could be read from "
                f"{[os.path.join(self.config_directory, c) for c in config_files]}"
            )

    def _module_dir_path(self):
        """Returns the directory of the source file defining the (sub)class of this instance."""
        module_path = os.path.abspath(inspect.getfile(self.__class__))
        return os.path.dirname(module_path)

    def _get_non_empty_entry(
        self, key: Union[str, List[str]]
    ) -> Union[float, str, List, Dict]:
        """
        Retrieves an entry from the configuration.

        String values starting with ``ENV_VAR_MARKER`` (at any nesting depth) are replaced by the value of the
        corresponding environment variable; note that this value is ``None`` if the variable is unset.

        :param key: key or list of keys to go through hierarchically
        :return: a copy of the queried json object with environment-variable references resolved
        :raises KeyError: if the entry, or any level on the way to it, is missing, null, or not a dict
        """
        if isinstance(key, str):
            key = [key]
        value = self.config
        for k in key:
            # stop as soon as the path is broken instead of calling .get on None or a non-dict value
            value = value.get(k) if isinstance(value, dict) else None
            if value is None:
                raise KeyError(f"Value for key '{key}' not set in configuration")
        return _get_entry_with_replaced_env_vars(
            value, env_var_marker=self.ENV_VAR_MARKER
        )

    def _get_existing_path(self, key: Union[str, List[str]], create=True) -> str:
        """
        Retrieves an existing local path from the configuration. Relative paths are interpreted relative to
        ``config_directory``.

        :param key: key or list of keys to go through hierarchically
        :param create: if True, a directory with the given path will be created on the fly if it does not exist.
        :return: the queried path (absolute, using the platform's path separator)
        :raises KeyError: if the entry is not set
        :raises FileNotFoundError: if the path does not exist and ``create`` is False
        """
        path_string = self._get_non_empty_entry(key)
        if os.path.isabs(path_string):
            path = path_string
        else:
            path = os.path.abspath(os.path.join(self.config_directory, path_string))
        if not os.path.exists(path):
            if isinstance(key, list):
                key = ".".join(key)  # purely for logging
            if not create:
                raise FileNotFoundError(
                    f"Configured directory {key}='{path}' does not exist."
                )
            log.info(f"Configured directory {key}='{path}' not found; will create it")
            # exist_ok: the directory may have been created by someone else after the existence check
            os.makedirs(path, exist_ok=True)
        return path.replace("/", os.sep)

    def _adjusted_path(self, path: str, relative: bool, check_existence: bool):
        """
        :param path: the path to adjust
        :param relative: If true, the returned path will be relative to the configuration directory.
        :param check_existence: if True, will raise an error when file does not exist
        :return: the adjusted path, either absolute or relative
        :raises FileNotFoundError: if ``check_existence`` is True and the path does not exist
        :raises ValueError: if ``relative`` is True and the path is not inside the configuration directory
        """
        path = os.path.abspath(path)
        if check_existence and not os.path.exists(path):
            raise FileNotFoundError(f"No such file: {path}")
        if relative:
            # config_directory may itself be relative (e.g. "."), so compare against its absolute form
            base_dir = os.path.abspath(self.config_directory)
            return str(Path(path).relative_to(base_dir))
        return path
class DefaultDataConfiguration(ConfigurationBase, ABC):
    """
    Reads default configuration entries and contains retrieval methods for a typical data-driven project.
    A typical config.json file would look like this:

    | {
    |   "data_raw": "data/raw",
    |   "data_cleaned": "data/cleaned",
    |   "data_processed": "data/processed",
    |   "data_ground_truth": "data/ground_truth",
    |   "visualizations": "data/visualizations",
    |   "artifacts": "data/artifacts",
    |   "temp": "temp",
    |   "data": "data"
    | }

    Each directory property reads the configuration entry of the same name via ``_get_existing_path``, so a
    configured directory that does not exist yet is created on first access.
    """

    # Stage identifiers accepted by ``datafile_path``
    PROCESSED = "processed"
    RAW = "raw"
    CLEANED = "cleaned"
    GROUND_TRUTH = "ground_truth"
    DATA = "data"

    @property
    def artifacts(self):
        return self._get_existing_path("artifacts")

    @property
    def visualizations(self):
        return self._get_existing_path("visualizations")

    @property
    def temp(self):
        return self._get_existing_path("temp")

    @property
    def data(self):
        return self._get_existing_path("data")

    @property
    def data_raw(self):
        return self._get_existing_path("data_raw")

    @property
    def data_cleaned(self):
        return self._get_existing_path("data_cleaned")

    @property
    def data_processed(self):
        return self._get_existing_path("data_processed")

    @property
    def data_ground_truth(self):
        return self._get_existing_path("data_ground_truth")

    def datafile_path(
        self,
        filename: str,
        stage="raw",
        relative=False,
        check_existence=False,
    ):
        """
        Builds the path of a file inside the directory of the given data stage.

        :param filename: name (or relative sub-path) of the file inside the stage directory
        :param stage: one of RAW, CLEANED, PROCESSED, GROUND_TRUTH or DATA, i.e. "raw", "cleaned",
            "processed", "ground_truth" or "data"
        :param relative: If True, the returned path will be relative to the configuration directory
        :param check_existence: if True, will raise an error when file does not exist
        :return: the path of the file
        :raises KeyError: for an unknown stage
        :raises FileNotFoundError: if ``check_existence`` is True and the file does not exist
        """
        basedir = self._data_basedir(stage)
        full_path = os.path.join(basedir, filename)
        return self._adjusted_path(full_path, relative, check_existence)

    def _data_basedir(self, stage):
        """
        Returns the configured directory for ``stage``.

        :raises KeyError: for an unknown stage
        """
        # compared against the class constants (not literals) so that subclasses may redefine them
        if stage == self.RAW:
            basedir = self.data_raw
        elif stage == self.CLEANED:
            basedir = self.data_cleaned
        elif stage == self.PROCESSED:
            basedir = self.data_processed
        elif stage == self.GROUND_TRUTH:
            basedir = self.data_ground_truth
        elif stage == self.DATA:
            basedir = self.data
        else:
            raise KeyError(f"Unknown stage: {stage}")
        return basedir

    def artifact_path(self, name: str, relative=False, check_existence=False):
        """
        Builds the path of an artifact inside the configured ``artifacts`` directory.

        :param name: name (or relative sub-path) of the artifact
        :param relative: If true, the returned path will be relative to the configuration directory.
        :param check_existence: if True, will raise an error when file does not exist
        :return: the path of the artifact
        :raises FileNotFoundError: if ``check_existence`` is True and the file does not exist
        """
        full_path = os.path.join(self.artifacts, name)
        return self._adjusted_path(full_path, relative, check_existence)
# Type variable standing for the concrete configuration class a ConfigProviderBase subclass provides;
# bounded so that only subclasses of ConfigurationBase can be substituted.
ConfigurationClass = TypeVar("ConfigurationClass", bound=ConfigurationBase)
class ConfigProviderBase(Generic[ConfigurationClass], ABC):
    """
    Class for providing a config-singleton. Should not be instantiated directly but instead subclassed with an
    appropriate subclass of ConfigurationBase substituting the generic type.

    Usage example:
    >>> from accsr.config import ConfigurationBase, ConfigProviderBase
    >>> class __MyConfigClass(ConfigurationBase):
    ...     pass
    >>> class __MyConfigProvider(ConfigProviderBase[__MyConfigClass]):
    ...     pass
    ...
    >>> _config_provider = __MyConfigProvider()
    ...
    >>> def get_config():
    ...     return _config_provider.get_config()
    """

    def __init__(self):
        """
        :raises TypeError: if the class was not parameterized with a concrete configuration class
        """
        self.__config_instance = None
        self._config_args = None
        self._config_kwargs = None
        self._config_constructor: Type[ConfigurationClass] = self._resolve_config_class()

    @classmethod
    def _resolve_config_class(cls) -> type:
        """
        Retrieves, at runtime, the configuration class substituted for the generic parameter, see
        https://stackoverflow.com/questions/48572831/how-to-access-the-type-arguments-of-typing-generic

        The MRO is searched for a base of the form ``<subclass of ConfigProviderBase>[SomeClass]``, so the
        parameterized base need not be the first base and may sit on an ancestor class.

        :raises TypeError: if no such base with a concrete class argument exists
        """
        # local import: the module-level typing import does not provide get_origin
        from typing import get_origin

        for klass in cls.__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                origin = get_origin(base)
                args = get_args(base)
                if (
                    isinstance(origin, type)
                    and issubclass(origin, ConfigProviderBase)
                    and args
                    and isinstance(args[0], type)  # skips unresolved TypeVars
                ):
                    return args[0]
        raise TypeError(
            f"{cls.__name__} must be subclassed with a concrete configuration class, "
            f"e.g. class MyProvider(ConfigProviderBase[MyConfiguration])"
        )

    def _should_update_config_instance(self, reload: bool, args, kwargs):
        """True if no instance exists yet, a reload is requested, or the constructor arguments changed."""
        return (
            self.__config_instance is None
            or reload
            or self._config_args != args
            or self._config_kwargs != kwargs
        )

    def get_config(self, reload=False, *args, **kwargs) -> ConfigurationClass:
        """
        Retrieves the configuration object (as singleton).

        A new configuration object is created on the first call, when ``reload`` is True, or when the
        constructor arguments differ from those of the previous call. Since ``reload`` is the first positional
        parameter, constructor arguments should be passed as keywords.

        :param reload: if True, the config will be reloaded from disk even if a suitable
            configuration object already exists. This is mainly useful in interactive environments like notebooks.
        :param args: passed to init of the configuration class
        :param kwargs: passed to init of the configuration class constructor
        :return: the (cached) configuration object
        """
        if self._should_update_config_instance(reload, args, kwargs):
            self._config_args = args
            self._config_kwargs = kwargs
            self.__config_instance = self._config_constructor(*args, **kwargs)
        return self.__config_instance