Coverage for src/accsr/remote_storage.py: 89%

import glob
import json
import logging.handlers
import os
import re
from contextlib import contextmanager
from copy import copy
from dataclasses import asdict, dataclass, field, is_dataclass
from enum import Enum
from functools import cached_property
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Literal,
    Optional,
    Pattern,
    Protocol,
    Sequence,
    Tuple,
    Union,
    cast,
    runtime_checkable,
)

from libcloud.storage.base import Container, StorageDriver
from libcloud.storage.providers import get_driver
from libcloud.storage.types import (
    ContainerAlreadyExistsError,
    InvalidContainerNameError,
)
from tqdm import tqdm

from accsr.files import md5sum

log = logging.getLogger(__name__)


def _to_optional_pattern(regex: Optional[Union[str, Pattern]]) -> Optional[Pattern]:
    if isinstance(regex, str):
        return re.compile(regex)
    return regex


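# Usage sketch (illustrative only, not part of the original module): strings are
# compiled, while compiled patterns and None pass through unchanged.
#
#   _to_optional_pattern(r"data/.*")           # -> re.Pattern
#   _to_optional_pattern(re.compile("data/"))  # -> the same pattern object
#   _to_optional_pattern(None)                 # -> None

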
class _SummariesJSONEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, TransactionSummary):
            # special case for TransactionSummary, since the drivers are not serializable
            # and dataclasses.asdict calls deepcopy
            result = copy(o.__dict__)
            _replace_driver_by_name(result)
            return result
        if is_dataclass(o):
            return asdict(o)
        if isinstance(o, RemoteObjectProtocol):
            result = copy(o.__dict__)
            _replace_driver_by_name(result)
            return result
        if isinstance(o, SyncObject):
            return o.to_dict(make_serializable=True)
        return str(o)


def _replace_driver_by_name(obj):
    # The driver object from libcloud stores a connection and is not serializable.
    # Since we sometimes want to be able to deepcopy these objects around,
    # we replace the driver by its name. This is needed for `asdict` to work.
    if isinstance(obj, RemoteObjectProtocol) and hasattr(obj, "driver"):
        obj.driver = obj.driver.name  # type: ignore
    if isinstance(obj, (list, tuple)):
        for item in obj:
            _replace_driver_by_name(item)
    if isinstance(obj, dict):
        for key, value in obj.items():
            _replace_driver_by_name(value)


class _JsonReprMixin:
    def to_json(self) -> str:
        return json.dumps(self, indent=2, sort_keys=True, cls=_SummariesJSONEncoder)

    def __repr__(self):
        return f"\n{self.__class__.__name__}: \n{self.to_json()}\n"


@contextmanager
def _switch_to_dir(path: Optional[str] = None) -> Generator[None, None, None]:
    if path:
        cur_dir = os.getcwd()
        try:
            os.chdir(path)
            yield
        finally:
            os.chdir(cur_dir)
    else:
        yield


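# Usage sketch (illustrative only, not part of the original module): temporarily
# resolve relative paths and globs against a chosen prefix; the previous working
# directory is restored even if the body raises.
#
#   with _switch_to_dir("/tmp"):
#       txt_files = glob.glob("*.txt")  # resolved relative to /tmp

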
class Provider(str, Enum):
    GOOGLE_STORAGE = "google_storage"
    S3 = "s3"
    AZURE_BLOBS = "azure_blobs"


@runtime_checkable
class RemoteObjectProtocol(Protocol):
    """
    Protocol of classes that describe remote objects. Describes information about the remote object and functionality
    to download the object.
    """

    name: str
    size: int
    hash: int
    provider: Union[Provider, str]

    def download(
        self, download_path, overwrite_existing=False
    ) -> Optional["RemoteObjectProtocol"]:
        pass


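# A minimal in-memory object satisfying this protocol, e.g. for tests. This is a
# hypothetical sketch, not part of the module; real objects come from the libcloud
# driver. Since the protocol is runtime_checkable, isinstance only verifies that the
# attributes and the download method exist.
#
#   @dataclass
#   class _FakeRemoteObject:
#       name: str
#       size: int
#       hash: int
#       provider: str = "s3"
#
#       def download(self, download_path, overwrite_existing=False):
#           return None
#
#   assert isinstance(_FakeRemoteObject("a/b.txt", 3, 0), RemoteObjectProtocol)

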
class SyncObject(_JsonReprMixin):
    """
    Class representing the sync-status between a local path and a remote object. Is mainly used for creating
    summaries and syncing within RemoteStorage and for introspection before and after push/pull transactions.

    It is not recommended to create or manipulate instances of this class outside RemoteStorage, in particular
    in user code. This class forms part of the public interface because instances of it are given to users for
    introspection.
    """

    def __init__(
        self,
        local_path: Optional[str] = None,
        remote_obj: Optional[RemoteObjectProtocol] = None,
        remote_path: Optional[str] = None,
        remote_obj_overridden_md5_hash: Optional[int] = None,
    ):
        """
        :param local_path: path to the local file
        :param remote_obj: remote object
        :param remote_path: path to the remote file (always in linux style)
        :param remote_obj_overridden_md5_hash: pass this to override the hash of the remote object
            (by default, the hash attribute of the remote object is used).
            Setting this might be useful for Azure blob storage, as uploads to it may be chunked,
            and the md5 hash of the remote object then differs from the hash of the local file.
            The hash is used to check whether the local and remote files are equal.
        """
        if remote_path is not None:
            remote_path = remote_path.lstrip("/")
        if remote_obj is not None:
            remote_obj = copy(remote_obj)
            remote_obj.name = remote_obj.name.lstrip("/")

        self.exists_locally = False
        self.local_path = None
        self.set_local_path(local_path)

        if self.local_path is None and remote_obj is None:
            raise ValueError("Either a local path or a remote object has to be passed.")

        self.remote_obj = remote_obj

        if remote_path is not None:
            if remote_obj is not None and remote_obj.name != remote_path:
                raise ValueError(
                    f"Passed both remote_path and remote_obj but the paths don't agree: "
                    f"{remote_path} != {remote_obj.name}"
                )
            self.remote_path = remote_path
        else:
            if remote_obj is None:
                raise ValueError("Either remote_path or remote_obj has to be not None.")
            self.remote_path = remote_obj.name

        if self.exists_locally:
            assert self.local_path is not None
            self.local_size = os.path.getsize(self.local_path)
            self.local_hash = md5sum(self.local_path)
        else:
            self.local_size = 0
            self.local_hash = None

        if remote_obj_overridden_md5_hash is not None:
            if remote_obj is None:
                raise ValueError(
                    "remote_obj_overridden_md5_hash can only be set if remote_obj is not None"
                )
            self.remote_hash = remote_obj_overridden_md5_hash
        elif remote_obj is not None:
            self.remote_hash = remote_obj.hash
        else:
            self.remote_hash = None

    @property
    def name(self):
        return self.remote_path

    @property
    def exists_on_target(self) -> bool:
        """
        True iff the file exists in both locations
        """
        return self.exists_on_remote and self.exists_locally

    def set_local_path(self, path: Optional[str]):
        """
        Changes the local path of the SyncObject

        :param path: the new local path (or None)
        :return: None
        """
        if path is not None:
            local_path = os.path.abspath(path)
            if os.path.isdir(local_path):
                raise FileExistsError(
                    f"local_path needs to point to a file but pointed to a directory: {local_path}"
                )
            self.local_path = local_path
            self.exists_locally = os.path.isfile(local_path)

    @property
    def exists_on_remote(self):
        return self.remote_obj is not None

    @property
    def equal_md5_hash_sum(self):
        if self.exists_on_target:
            return self.local_hash == self.remote_hash
        return False

    def to_dict(self, make_serializable=True):
        result = copy(self.__dict__)
        if make_serializable:
            _replace_driver_by_name(result)

        result["exists_on_remote"] = self.exists_on_remote
        result["exists_on_target"] = self.exists_on_target
        result["equal_md5_hash_sum"] = self.equal_md5_hash_sum
        return result


def _get_total_size(objects: Sequence[SyncObject], mode="local"):
    """
    Computes the total size of the objects either on the local or on the remote side.

    :param objects: the SyncObjects for which the size should be computed
    :param mode: either 'local' or 'remote'
    :return: the total size of the objects on the specified side
    """
    permitted_modes = ["local", "remote"]
    if mode not in permitted_modes:
        raise ValueError(f"Unknown mode: {mode}. Has to be in {permitted_modes}.")
    if len(objects) == 0:
        return 0

    def get_size(obj: SyncObject):
        if mode == "local":
            if not obj.exists_locally:
                raise FileNotFoundError(
                    f"Cannot retrieve size of non-existing file: {obj.local_path}"
                )
            return obj.local_size
        if obj.remote_obj is None:
            raise FileNotFoundError(
                f"Cannot retrieve size of non-existing remote object corresponding to: {obj.local_path}"
            )
        return obj.remote_obj.size

    return sum(get_size(obj) for obj in objects)


@dataclass(repr=False)
class TransactionSummary(_JsonReprMixin):
    """
    Class representing the summary of a push or pull operation. Is mainly used for introspection before and after
    push/pull transactions.

    It is not recommended to create or manipulate instances of this class outside RemoteStorage, in particular
    in user code. This class forms part of the public interface because instances of it are given to users for
    introspection.
    """

    matched_source_files: List[SyncObject] = field(default_factory=list)
    not_on_target: List[SyncObject] = field(default_factory=list)
    on_target_eq_md5: List[SyncObject] = field(default_factory=list)
    on_target_neq_md5: List[SyncObject] = field(default_factory=list)
    unresolvable_collisions: Dict[str, Union[List[RemoteObjectProtocol], str]] = field(
        default_factory=dict
    )
    skipped_source_files: List[SyncObject] = field(default_factory=list)

    synced_files: List[SyncObject] = field(default_factory=list)
    sync_direction: Optional[Literal["push", "pull"]] = None

    def __post_init__(self):
        if self.sync_direction not in ["pull", "push", None]:
            raise ValueError(
                f"sync_direction can only be set to pull, push or None, instead got: {self.sync_direction}"
            )

    @property
    def files_to_sync(self) -> List[SyncObject]:
        """
        Returns the list of files that need synchronization.

        :return: list of all files that are not on the target or have different md5sums on target and remote
        """
        return self.not_on_target + self.on_target_neq_md5

    def size_files_to_sync(self) -> int:
        """
        Computes the total size of all objects that need synchronization. Raises a RuntimeError if the sync_direction
        property is not set to 'push' or 'pull'.

        :return: the total size of all local objects that need synchronization if self.sync_direction='push' and
            the size of all remote files that need synchronization if self.sync_direction='pull'
        """
        if self.sync_direction not in ["push", "pull"]:
            raise RuntimeError(
                "sync_direction has to be set to push or pull before computing sizes"
            )
        mode = "local" if self.sync_direction == "push" else "remote"
        return _get_total_size(self.files_to_sync, mode=mode)

    @property
    def requires_force(self) -> bool:
        """
        Getter of the requires_force property.

        :return: True iff a failure of the transaction can only be prevented by setting force=True.
        """
        return len(self.on_target_neq_md5) != 0

    @property
    def has_unresolvable_collisions(self) -> bool:
        """
        Getter of the has_unresolvable_collisions property.

        :return: True iff there exists a collision that cannot be resolved.
        """
        return len(self.unresolvable_collisions) != 0

    @property
    def all_files_analyzed(self) -> List[SyncObject]:
        """
        Getter of the all_files_analyzed property.

        :return: list of all analyzed source files
        """
        return self.skipped_source_files + self.matched_source_files

    def add_entry(
        self,
        synced_object: Union[SyncObject, str],
        collides_with: Optional[Union[List[RemoteObjectProtocol], str]] = None,
        skip: bool = False,
    ):
        """
        Adds a SyncObject to the summary.

        :param synced_object: either a SyncObject or a path to a local file.
        :param collides_with: specification of unresolvable collisions for the given sync object
        :param skip: if True, the object is marked to be skipped
        :return: None
        """
        if isinstance(synced_object, str):
            synced_object = SyncObject(local_path=synced_object)
        if skip:
            self.skipped_source_files.append(synced_object)
        else:
            self.matched_source_files.append(synced_object)
            if collides_with:
                self.unresolvable_collisions[synced_object.name] = collides_with
            elif synced_object.exists_on_target:
                if synced_object.equal_md5_hash_sum:
                    self.on_target_eq_md5.append(synced_object)
                else:
                    self.on_target_neq_md5.append(synced_object)
            else:
                self.not_on_target.append(synced_object)

    def get_short_summary_dict(self):
        """
        Returns a short summary of the transaction as a dictionary.
        """
        return {
            "sync_direction": self.sync_direction,
            "files_to_sync": len(self.files_to_sync),
            "total_size": self.size_files_to_sync(),
            "unresolvable_collisions": len(self.unresolvable_collisions),
            "synced_files": len(self.synced_files),
        }

    def print_short_summary(self):
        """
        Prints a short summary of the transaction (shorter than the full repr, which contains
        information about local and remote objects).
        """
        print(json.dumps(self.get_short_summary_dict(), indent=2))


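# Typical introspection flow (illustrative sketch; `storage` stands for an instance
# of the RemoteStorage class defined below):
#
#   summary = storage.push("data", dryrun=True)  # plan only, nothing is uploaded
#   summary.print_short_summary()                # counts and total size
#   if summary.requires_force:
#       print([obj.name for obj in summary.on_target_neq_md5])
#   if not summary.has_unresolvable_collisions:
#       storage.push("data", force=summary.requires_force)

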
@dataclass
class RemoteStorageConfig:
    """
    Contains all necessary information to establish a connection
    to a bucket within the remote storage, and the base path on the remote.
    """

    provider: str
    key: str
    bucket: str
    secret: str = field(repr=False)
    region: Optional[str] = None
    host: Optional[str] = None
    port: Optional[int] = None
    base_path: str = ""
    secure: bool = True


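# Construction sketch (values are placeholders; host and port are only needed for
# self-hosted, S3-compatible backends such as MinIO):
#
#   config = RemoteStorageConfig(
#       provider=Provider.S3,
#       key="<access-key>",
#       secret="<secret-key>",
#       bucket="my-bucket",
#       base_path="datasets/v1",
#   )

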
class RemoteStorage:
    """
    Wrapper around libcloud for accessing remote storage services.
    """

    def __init__(
        self,
        conf: RemoteStorageConfig,
        add_extra_to_upload: Optional[Callable[[SyncObject], dict]] = None,
        remote_hash_extractor: Optional[Callable[[RemoteObjectProtocol], int]] = None,
    ):
        """
        :param conf: configuration for the remote storage
        :param add_extra_to_upload: a function that takes a `SyncObject` and returns a dictionary with extra parameters
            that should be passed to the `upload_object` method of the storage driver as value of the `extra` kwarg.
            This can be used to set custom metadata or other parameters. For example, for Azure blob storage, one can
            set the hash of the local file as metadata by using
            `add_extra_to_upload = lambda sync_object: {"meta_data": {"md5": sync_object.local_hash}}`.
        :param remote_hash_extractor: a function that extracts the hash from a `RemoteObjectProtocol` object.
            This is useful for Azure blob storage, as uploads to it may be chunked, and the md5 hash of the remote
            object then differs from the hash of the local file. In that case, one can add the hash of the local
            file to the metadata using `add_extra_to_upload`, and then use this function to extract the hash from
            the remote object. If not set, the `hash` attribute of the `RemoteObjectProtocol` object is used.
        """
        self._bucket: Optional[Container] = None
        self._conf = conf
        self._provider = conf.provider
        self._remote_base_path = ""
        self.set_remote_base_path(conf.base_path)
        possible_driver_kwargs = {
            "key": self.conf.key,
            "secret": self.conf.secret,
            "region": self.conf.region,
            "host": self.conf.host,
            "port": self.conf.port,
            "secure": self.conf.secure,
        }
        self.driver_kwargs = {
            k: v for k, v in possible_driver_kwargs.items() if v is not None
        }
        self.add_extra_to_upload = add_extra_to_upload
        self.remote_hash_extractor = remote_hash_extractor

    def create_bucket(self, exist_ok: bool = True):
        try:
            log.info(
                f"Creating bucket {self.conf.bucket} from configuration {self.conf}."
            )
            self.driver.create_container(container_name=self.conf.bucket)
        except (ContainerAlreadyExistsError, InvalidContainerNameError):
            if not exist_ok:
                raise
            log.info(
                f"Bucket {self.conf.bucket} already exists (or the name was invalid)."
            )

    @property
    def conf(self) -> RemoteStorageConfig:
        return self._conf

    @property
    def provider(self) -> str:
        return self._provider

    @property
    def remote_base_path(self) -> str:
        return self._remote_base_path

    def set_remote_base_path(self, path: Optional[str]):
        """
        Changes the base path in the remote storage
        (overriding the base path extracted from RemoteStorageConfig during instantiation).
        Pull and push operations will only affect files within the remote base path.

        :param path: a path with linux-like separators
        """
        if path is None:
            path = ""
        else:
            # google storage pulling and listing does not work with paths starting with "/"
            path = path.strip().lstrip("/")
        self._remote_base_path = path.strip()

    @cached_property
    def bucket(self) -> Container:
        log.info(f"Establishing connection to bucket {self.conf.bucket}")
        return self.driver.get_container(self.conf.bucket)

    @cached_property
    def driver(self) -> StorageDriver:
        storage_driver_factory = get_driver(self.provider)
        return storage_driver_factory(**self.driver_kwargs)

    def _execute_sync(
        self, sync_object: SyncObject, direction: Literal["push", "pull"], force=False
    ) -> SyncObject:
        """
        Synchronizes the local and the remote file in the given direction. Will raise an error if a file from the
        source would overwrite an already existing file on the target and force=False. In this case, no operations
        will be performed on the target.

        :param sync_object: instance of SyncObject that will be used as basis for synchronization. Usually
            created from a get_*_summary method.
        :param direction: either "push" or "pull"
        :param force: if True, all already existing files on the target (with a different md5sum than the source
            files) will be overwritten.
        :return: a SyncObject that represents the status of remote and target after the synchronization
        """
        if sync_object.equal_md5_hash_sum:
            log.debug(
                f"Skipping {direction} of {sync_object.name} because of coinciding hash sums"
            )
            return sync_object

        if sync_object.exists_on_target and not force:
            raise ValueError(
                f"Cannot perform {direction} because {sync_object.name} already exists and force is False"
            )

        if direction == "push":
            if not sync_object.exists_locally:
                raise FileNotFoundError(
                    f"Cannot push non-existing file: {sync_object.local_path}"
                )
            assert sync_object.local_path is not None

            extra = (
                self.add_extra_to_upload(sync_object)
                if self.add_extra_to_upload is not None
                else None
            )
            remote_obj = cast(
                RemoteObjectProtocol,
                self.bucket.upload_object(
                    sync_object.local_path,
                    sync_object.remote_path,
                    extra=extra,
                    verify_hash=False,
                ),
            )

            if self.remote_hash_extractor is not None:
                remote_obj_overridden_md5_hash = self.remote_hash_extractor(remote_obj)
            else:
                remote_obj_overridden_md5_hash = None
            return SyncObject(
                sync_object.local_path,
                remote_obj,
                remote_obj_overridden_md5_hash=remote_obj_overridden_md5_hash,
            )

        elif direction == "pull":
            if None in [sync_object.remote_obj, sync_object.local_path]:
                raise RuntimeError(
                    f"Cannot pull without remote object and local path. Affects: {sync_object.name}"
                )
            assert sync_object.local_path is not None
            if os.path.isdir(sync_object.local_path):
                raise FileExistsError(
                    f"Cannot pull file to a path which is an existing directory: {sync_object.local_path}"
                )

            log.debug(f"Fetching {sync_object.remote_obj.name} from {self.bucket.name}")
            os.makedirs(os.path.dirname(sync_object.local_path), exist_ok=True)
            sync_object.remote_obj.download(
                sync_object.local_path, overwrite_existing=force
            )
            return SyncObject(sync_object.local_path, sync_object.remote_obj)
        else:
            raise ValueError(
                f"Unknown direction {direction}, has to be either 'push' or 'pull'."
            )

    @staticmethod
    def _get_remote_path(remote_obj: RemoteObjectProtocol) -> str:
        """
        Returns the full path to the remote object. The resulting path never starts with "/" as it can cause problems
        with some backends (e.g. google cloud storage).
        """
        return remote_obj.name.lstrip("/")

    def _get_relative_remote_path(self, remote_obj: RemoteObjectProtocol) -> str:
        """
        Returns the path to the remote object relative to the configured base dir (as expected by pull for a
        single file)
        """
        result = remote_obj.name
        result = result[len(self.remote_base_path) :]
        result = result.lstrip("/")
        return result

    def _full_remote_path(self, remote_path: str) -> str:
        """
        :param remote_path: remote path on storage bucket relative to the configured remote base path,
            e.g. 'data/some_file.json'
        :return: full remote path on storage bucket. With the example above this gives
            "remote_base_path/data/some_file.json". Does not start with "/" even if remote_base_path is empty.
        """
        # in google cloud, paths cannot begin with / for pulling or listing (for pushing they can though...)
        remote_path = "/".join([self.remote_base_path, remote_path])
        return remote_path.lstrip("/")

    @staticmethod
    def _listed_due_to_name_collision(
        full_remote_path: str, remote_object: RemoteObjectProtocol
    ) -> bool:
        """
        Checks whether a remote object was falsely listed because its name starts with the same
        characters as full_remote_path.

        Example 1: full remote path is 'pull/this/dir' and remote storage includes paths like 'pull/this/dir_suffix'.
        Example 2: full remote path is 'delete/this/file' and remote storage includes paths like 'delete/this/file_2'.

        All such paths will be listed in bucket.list_objects(full_remote_path), and we need to exclude them in
        most methods like pull or delete.

        :param full_remote_path: usually the output of self._full_remote_path(remote_path)
        :param remote_object: the object to check
        :return: True iff the object was listed only due to a name collision
        """
        # no name collisions possible in this case
        if full_remote_path.endswith("/") or full_remote_path == "":
            return False

        # remove leading / for comparison of paths
        full_remote_path = full_remote_path.lstrip("/")
        object_remote_path = RemoteStorage._get_remote_path(remote_object)
        is_in_selected_dir = object_remote_path.startswith(full_remote_path + "/")
        is_selected_file = object_remote_path == full_remote_path
        return not (is_in_selected_dir or is_selected_file)

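    # Illustrative outcomes (sketch, not part of the module), given an object `obj`
    # with the stated name:
    #   _listed_due_to_name_collision("pull/this/dir", obj)
    #       obj.name == "pull/this/dir_suffix"   -> True  (shared prefix only; excluded)
    #       obj.name == "pull/this/dir/file.txt" -> False (inside the selected dir; kept)
    #       obj.name == "pull/this/dir"          -> False (the selected file itself; kept)
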
    def _execute_sync_from_summary(
        self, summary: TransactionSummary, dryrun: bool = False, force: bool = False
    ) -> TransactionSummary:
        """
        Executes a transaction summary.

        :param summary: the transaction summary
        :param dryrun: if True, logs any error that would have prevented the execution and returns the summary
            without actually executing the sync.
        :param force: raises an error if dryrun=False and any files would be overwritten by the sync
        :return: the input transaction summary. Note that the function potentially alters the state of the
            input summary.
        """
        if dryrun:
            log.info(f"Skipping {summary.sync_direction} because dryrun=True")
            if summary.has_unresolvable_collisions:
                log.warning(
                    "This transaction has unresolvable collisions and would not succeed."
                )
            if summary.requires_force and not force:
                log.warning(
                    "This transaction requires overwriting of files and would not succeed without force=True"
                )
            return summary

        if summary.has_unresolvable_collisions:
            raise FileExistsError(
                f"Found name collisions of files with directories, not syncing anything. "
                f"Suggestion: perform a dryrun and analyze the summary. "
                f"Affected names: {list(summary.unresolvable_collisions.keys())}. "
            )

        if summary.requires_force and not force:
            raise FileExistsError(
                f"Operation requires overwriting of objects but `force=False`. "
                f"Suggestion: perform a dryrun and analyze the summary. "
                f"Affected names: {[obj.name for obj in summary.on_target_neq_md5]}. "
            )

        desc = f"{summary.sync_direction}ing (bytes)"
        if force:
            desc = "force " + desc
        with tqdm(total=summary.size_files_to_sync(), desc=desc) as pbar:
            for sync_obj in summary.files_to_sync:
                assert summary.sync_direction is not None
                synced_obj = self._execute_sync(
                    sync_obj, direction=summary.sync_direction, force=force
                )
                pbar.update(synced_obj.local_size)
                summary.synced_files.append(synced_obj)
        return summary

    def pull(
        self,
        remote_path: str,
        local_base_dir: str = "",
        force: bool = False,
        include_regex: Optional[Union[Pattern, str]] = None,
        exclude_regex: Optional[Union[Pattern, str]] = None,
        convert_to_linux_path: bool = True,
        dryrun: bool = False,
        path_regex: Optional[Union[Pattern, str]] = None,
        strip_abspath_prefix: Optional[str] = None,
        strip_abs_local_base_dir: bool = True,
    ) -> TransactionSummary:
        r"""
        Pull either a file or a directory under the given path relative to local_base_dir.

        :param remote_path: remote path on storage bucket relative to the configured remote base path,
            e.g. 'data/ground_truth/some_file.json'. Can also be an absolute local path if ``strip_abspath_prefix``
            is specified.
        :param local_base_dir: Local base directory for constructing the local path,
            e.g. passing 'local_base_dir' will download to the path
            'local_base_dir/data/ground_truth/some_file.json' in the above example.
        :param force: If False, pull will raise an error if an already existing file deviates from the remote in
            its md5sum. If True, these files are overwritten.
        :param include_regex: If not None, only files with paths matching the regex will be pulled. This is useful
            for filtering files within a remote directory before pulling them.
        :param exclude_regex: If not None, files with paths matching the regex will be excluded from the pull.
            Takes precedence over ``include_regex``, i.e. if a file matches both, it will be excluded.
        :param convert_to_linux_path: if True, will convert windows path to linux path (as needed by remote storage)
            and thus passing a remote path like 'data\my\path' will be converted to 'data/my/path' before pulling.
            This should only be set to False if you want to pull a remote object with '\' in its file name
            (which is discouraged).
        :param dryrun: If True, simulates the pull operation and returns the remote objects that would have been
            pulled.
        :param path_regex: DEPRECATED! Use ``include_regex`` instead.
        :param strip_abspath_prefix: Will only have an effect if the ``remote_path`` is absolute.
            Then the given prefix is removed from it before pulling. This is useful for pulling files from a
            remote storage by directly specifying absolute local paths instead of first converting them to
            actual remote paths. Similar in logic to ``local_path_prefix`` in ``push``.
            A common use case is to always set ``local_base_dir`` to the same value and to always pass absolute
            paths as ``remote_path`` to ``pull``.
        :param strip_abs_local_base_dir: If True, and ``local_base_dir`` is an absolute path, then
            ``local_base_dir`` will be treated as ``strip_abspath_prefix``. See the explanation of
            ``strip_abspath_prefix``.
        :return: An object describing the summary of the operation.
        """

        if strip_abs_local_base_dir and os.path.isabs(local_base_dir):
            if strip_abspath_prefix is not None:
                raise ValueError(
                    f"Cannot specify both `strip_abs_local_base_dir`={strip_abs_local_base_dir} "
                    f"and `strip_abspath_prefix`={strip_abspath_prefix} "
                    f"when `local_base_dir`={local_base_dir} is an absolute path."
                )
            strip_abspath_prefix = local_base_dir

        remote_path_is_abs = remote_path.startswith("/") or os.path.isabs(remote_path)

        if strip_abspath_prefix is not None and remote_path_is_abs:
            remote_path = remote_path.replace("\\", "/")
            strip_abspath_prefix = strip_abspath_prefix.replace("\\", "/").rstrip("/")
            if not remote_path.startswith(strip_abspath_prefix):
                raise ValueError(
                    f"Remote path {remote_path} is absolute but does not start "
                    f"with the given prefix {strip_abspath_prefix}"
                )
            # +1 for removing the leading '/'
            remote_path = remote_path[len(strip_abspath_prefix) + 1 :]

        include_regex = self._handle_deprecated_path_regex(include_regex, path_regex)
        summary = self._get_pull_summary(
            remote_path,
            local_base_dir,
            include_regex=include_regex,
            exclude_regex=exclude_regex,
            convert_to_linux_path=convert_to_linux_path,
        )
        if len(summary.all_files_analyzed) == 0:
            log.warning(f"No files found in remote storage under path: {remote_path}")
        return self._execute_sync_from_summary(summary, dryrun=dryrun, force=force)

    def _get_destination_path(
        self, obj: RemoteObjectProtocol, local_base_dir: str
    ) -> str:
        """
        Returns the destination path of the given object
        """
        relative_obj_path = self._get_relative_remote_path(obj)
        return os.path.join(local_base_dir, relative_obj_path)

    def _get_pull_summary(
        self,
        remote_path: str,
        local_base_dir: str = "",
        include_regex: Optional[Union[Pattern, str]] = None,
        exclude_regex: Optional[Union[Pattern, str]] = None,
        convert_to_linux_path: bool = True,
        path_regex: Optional[Union[Pattern, str]] = None,
    ) -> TransactionSummary:
        r"""
        Creates a TransactionSummary of the specified pull operation.

        :param remote_path: remote path on storage bucket relative to the configured remote base path,
            e.g. 'data/ground_truth/some_file.json'
        :param local_base_dir: Local base directory for constructing the local path.
            Example: passing 'local_base_dir' will download to the path
            'local_base_dir/data/ground_truth/some_file.json' in the above example.
        :param include_regex: If not None, only files with paths matching the regex will be pulled. This is useful
            for filtering files within a remote directory before pulling them.
        :param exclude_regex: If not None, only files with paths not matching the regex will be pulled.
            Takes precedence over include_regex, i.e. if a file matches both, it will be excluded.
        :param convert_to_linux_path: if True, will convert windows path to linux path (as needed by remote storage)
            and thus passing a remote path like 'data\my\path' will be converted to 'data/my/path' before pulling.
            This should only be set to False if you want to pull a remote object with '\' in its file name
            (which is discouraged).
        :param path_regex: DEPRECATED! Use ``include_regex`` instead.
        :return: the transaction summary of the pull operation
        """
        include_regex = self._handle_deprecated_path_regex(include_regex, path_regex)

        include_regex = _to_optional_pattern(include_regex)
        exclude_regex = _to_optional_pattern(exclude_regex)

        local_base_dir = os.path.abspath(local_base_dir)
        if convert_to_linux_path:
            remote_path = remote_path.replace("\\", "/")

        summary = TransactionSummary(sync_direction="pull")
        full_remote_path = self._full_remote_path(remote_path)
        # noinspection PyTypeChecker
        remote_objects = cast(
            List[RemoteObjectProtocol], list(self.bucket.list_objects(full_remote_path))
        )

        for remote_obj in tqdm(
            remote_objects,
            desc=f"Scanning remote paths in {self.bucket.name}/{full_remote_path}: ",
        ):
            local_path = None
            collides_with = None
            if (remote_obj.size == 0) or (
                self._listed_due_to_name_collision(full_remote_path, remote_obj)
            ):
                log.debug(
                    f"Skipping {remote_obj.name} since it was listed due to name collisions"
                )
                skip = True
            else:
                relative_obj_path = self._get_relative_remote_path(remote_obj)
                skip = self._should_skip(relative_obj_path, include_regex, exclude_regex)

            if not skip:
                local_path = self._get_destination_path(remote_obj, local_base_dir)
                if os.path.isdir(local_path):
                    collides_with = local_path

            remote_obj_overridden_md5_hash = (
                self.remote_hash_extractor(remote_obj)
                if self.remote_hash_extractor is not None
                else None
            )
            sync_obj = SyncObject(
                local_path=local_path,
                remote_obj=remote_obj,
                remote_obj_overridden_md5_hash=remote_obj_overridden_md5_hash,
            )

            summary.add_entry(
                sync_obj,
                skip=skip,
                collides_with=collides_with,
            )

        return summary

    def get_push_remote_path(self, local_path: str) -> str:
        """
        Get the full path within a remote storage bucket for pushing.

        :param local_path: the local path to the file
        :return: the remote path that corresponds to the local path
        """
        return (
            "/".join([self.remote_base_path, local_path])
            .replace(os.sep, "/")
            .lstrip("/")
        )

    def _get_push_summary(
        self,
        path: str,
        local_path_prefix: Optional[str] = None,
        include_regex: Optional[Union[Pattern, str]] = None,
        exclude_regex: Optional[Union[Pattern, str]] = None,
        path_regex: Optional[Union[Pattern, str]] = None,
    ) -> TransactionSummary:
        """
        Retrieves the summary of the push-transaction plan, before it has been executed.
        Nothing will be pushed and the synced_files entry of the summary will be an empty list.

        :param path: Path to the local object (file or directory) to be uploaded, may be absolute or relative.
            globs are permitted, thus ``path`` may contain wildcards.
        :param local_path_prefix: path names on the remote will be relative to this path. Thus, specifying
            for example ``local_path_prefix=/bar/foo`` (on a unix system) and ``path=baz``
            will push ``/bar/foo/baz`` to ``remote_base_path/baz``. The same will happen if
            ``path=/bar/foo/baz`` is specified.
            **NOTE**: if ``local_path_prefix`` is specified and ``path`` is absolute, it is assumed that
            ``path`` is a child of ``local_path_prefix``. If this is not the case, an error will be raised.
        :param include_regex: If not None, only files with paths matching the regex will be pushed.
            Note that paths matched against the regex will be relative to ``local_path_prefix``.
        :param exclude_regex: If not None, only files with paths not matching the regex will be pushed.
            Takes precedence over ``include_regex``, i.e. if a file matches both regexes, it will be excluded.
            Note that paths matched against the regex will be relative to ``local_path_prefix``.
        :param path_regex: DEPRECATED! Same as ``include_regex``.
        :return: the summary object
        """
        summary = TransactionSummary(sync_direction="push")
        include_regex = self._handle_deprecated_path_regex(include_regex, path_regex)

        if local_path_prefix is not None:
            local_path_prefix = os.path.abspath(local_path_prefix)
        include_regex = _to_optional_pattern(include_regex)
        exclude_regex = _to_optional_pattern(exclude_regex)

        _path = Path(path)
        if _path.is_absolute() and local_path_prefix:
            try:
                path = str(_path.relative_to(local_path_prefix))
            except ValueError:
                raise ValueError(
                    f"Specified {path=} is not a child of {local_path_prefix=}"
                )

        # at this point, path is relative to local_path_prefix
        with _switch_to_dir(local_path_prefix):
            # collect all paths to scan
            all_files_analyzed = []
            for local_path in glob.glob(path):
                if os.path.isfile(local_path):
                    all_files_analyzed.append(local_path)
                elif os.path.isdir(local_path):
                    for root, _, fs in os.walk(local_path):
                        all_files_analyzed.extend([os.path.join(root, f) for f in fs])
            if len(all_files_analyzed) == 0:
                raise FileNotFoundError(
                    f"No files found under {path=} with {local_path_prefix=}"
                )

            for file in tqdm(
                all_files_analyzed,
                desc=f"Scanning files in {os.path.join(os.getcwd(), path)}: ",
            ):
                collides_with = None
                remote_obj = None
                skip = self._should_skip(file, include_regex, exclude_regex)

                remote_path = self.get_push_remote_path(file)

                all_matched_remote_obj = cast(
                    List[RemoteObjectProtocol], self.bucket.list_objects(remote_path)
                )
                matched_remote_obj = [
                    obj
                    for obj in all_matched_remote_obj
                    if not self._listed_due_to_name_collision(remote_path, obj)
                ]

                # name collision of local file with remote dir
                if len(matched_remote_obj) > 1:
                    collides_with = matched_remote_obj

                elif matched_remote_obj:
                    remote_obj = matched_remote_obj[0]
                remote_obj_overridden_md5_hash = (
                    self.remote_hash_extractor(remote_obj)
                    if self.remote_hash_extractor is not None and remote_obj is not None
                    else None
                )
                synced_obj = SyncObject(
                    local_path=file,
                    remote_obj=remote_obj,
                    remote_path=remote_path,
                    remote_obj_overridden_md5_hash=remote_obj_overridden_md5_hash,
                )
                summary.add_entry(
                    synced_obj,
                    collides_with=collides_with,
                    skip=skip,
                )

        return summary

    @staticmethod
    def _should_skip(
        file: str, include_regex: Optional[Pattern], exclude_regex: Optional[Pattern]
    ):
        if include_regex is not None and not include_regex.match(file):
            log.debug(
                f"Skipping {file} since it does not match regular expression '{include_regex}'."
            )
            return True
        if exclude_regex is not None and exclude_regex.match(file):
            log.debug(
                f"Skipping {file} since it matches regular expression '{exclude_regex}'."
            )
            return True
        return False

    @staticmethod
    def _handle_deprecated_path_regex(
        include_regex: Optional[Union[Pattern, str]],
        path_regex: Optional[Union[Pattern, str]],
    ):
        if path_regex is not None:
            log.warning(
                "Using deprecated parameter 'path_regex'. Use 'include_regex' instead."
            )
            if include_regex is not None:
                raise ValueError(
                    "Cannot specify both 'path_regex' and 'include_regex'. "
                    "Use only 'include_regex' instead, 'path_regex' is deprecated. "
                    f"Got {path_regex=} and {include_regex=}"
                )
            include_regex = path_regex
        return include_regex

    def push(
        self,
        path: str,
        local_path_prefix: Optional[str] = None,
        force: bool = False,
        include_regex: Optional[Union[Pattern, str]] = None,
        exclude_regex: Optional[Union[Pattern, str]] = None,
        dryrun: bool = False,
        path_regex: Optional[Union[Pattern, str]] = None,
    ) -> TransactionSummary:
        """
        Upload files into the remote storage.
        Does not upload files for which the md5sum matches existing remote files.
        The remote path for uploading will be constructed from the remote_base_path and the provided path.
        The ``local_path_prefix`` serves for finding the directory on the local system or for stripping off
        parts of absolute paths if path is absolute, see examples below.

        Examples:
           1) path=foo/bar, local_path_prefix=None -->
              ./foo/bar uploaded to remote_base_path/foo/bar
           2) path=/home/foo/bar, local_path_prefix=None -->
              /home/foo/bar uploaded to remote_base_path/home/foo/bar
           3) path=bar, local_path_prefix=/home/foo -->
              /home/foo/bar uploaded to remote_base_path/bar
           4) path=/home/foo/bar, local_path_prefix=/home/foo -->
              /home/foo/bar uploaded to remote_base_path/bar (same as 3)
           5) path=/home/baz/bar, local_path_prefix=/home/foo -->
              ValueError: Specified path=/home/baz/bar is not a child of local_path_prefix=/home/foo

        :param path: Path to the local object (file or directory) to be uploaded, may be absolute or relative.
            globs are supported as well, thus ``path`` may be a pattern like ``*.txt``.
        :param local_path_prefix: Prefix to be concatenated with ``path``
        :param force: If False, push will raise an error if an already existing remote file deviates from the local
            in its md5sum. If True, these files are overwritten.
        :param include_regex: If not None, only files with paths matching the regex will be pushed.
            Note that paths matched against the regex will be relative to ``local_path_prefix``.
        :param exclude_regex: If not None, only files with paths not matching the regex will be pushed. Takes
            precedence over ``include_regex``, i.e. if a file matches both regexes, it will be excluded.
            Note that paths matched against the regex will be relative to ``local_path_prefix``.
        :param dryrun: If True, simulates the push operation and returns the summary
            (with synced_files being an empty list).
        :param path_regex: DEPRECATED! Same as ``include_regex``.
        :return: An object describing the summary of the operation.
        """
        include_regex = self._handle_deprecated_path_regex(include_regex, path_regex)
        summary = self._get_push_summary(
            path,
            local_path_prefix,
            include_regex=include_regex,
            exclude_regex=exclude_regex,
        )
        return self._execute_sync_from_summary(summary, dryrun=dryrun, force=force)

    def delete(
        self,
        remote_path: str,
        include_regex: Optional[Union[Pattern, str]] = None,
        exclude_regex: Optional[Union[Pattern, str]] = None,
        path_regex: Optional[Union[Pattern, str]] = None,
    ) -> List[RemoteObjectProtocol]:
        """
        Deletes a file or a directory under the given path relative to the configured remote base path.
        Use with caution!

        :param remote_path: remote path on storage bucket relative to the configured remote base path.
        :param include_regex: If not None, only files with paths matching the regex will be deleted.
        :param exclude_regex: If not None, only files with paths not matching the regex will be deleted.
            Takes precedence over ``include_regex``, i.e. if a file matches both regexes, it will be excluded.
        :param path_regex: DEPRECATED! Same as ``include_regex``.
        :return: list of remote objects referring to all deleted files
        """
        include_regex = self._handle_deprecated_path_regex(include_regex, path_regex)
        include_regex = _to_optional_pattern(include_regex)
        exclude_regex = _to_optional_pattern(exclude_regex)

        full_remote_path = self._full_remote_path(remote_path)

        remote_objects = cast(
            List[RemoteObjectProtocol], self.bucket.list_objects(full_remote_path)
        )
        if len(remote_objects) == 0:
            log.warning(
                f"No such remote file or directory: {full_remote_path}. Not deleting anything"
            )
            return []
        deleted_objects = []
        for remote_obj in remote_objects:
            if self._listed_due_to_name_collision(full_remote_path, remote_obj):
                log.debug(
                    f"Skipping deletion of {remote_obj.name} as it was listed due to name collision"
                )
                continue

            relative_obj_path = self._get_relative_remote_path(remote_obj)
            if include_regex is not None and not include_regex.match(relative_obj_path):
                log.info(f"Skipping {relative_obj_path} due to regex {include_regex}")
                continue
            if exclude_regex is not None and exclude_regex.match(relative_obj_path):
                log.info(f"Skipping {relative_obj_path} due to regex {exclude_regex}")
                continue
            log.debug(f"Deleting {remote_obj.name}")
            self.bucket.delete_object(remote_obj)  # type: ignore
            deleted_objects.append(remote_obj)
        return deleted_objects

    def list_objects(self, remote_path: str) -> List[RemoteObjectProtocol]:
        """
        :param remote_path: remote path on storage bucket relative to the configured remote base path.
        :return: list of remote objects under the remote path (multiple entries if the remote path is a directory)
        """
        full_remote_path = self._full_remote_path(remote_path)
        return self.bucket.list_objects(full_remote_path)  # type: ignore
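

# End-to-end usage sketch (illustrative; bucket name, paths and credentials are
# placeholders, and `config` refers to the RemoteStorageConfig example sketched above):
#
#   storage = RemoteStorage(config)
#   storage.create_bucket(exist_ok=True)
#
#   # upload ./data/** to <bucket>/datasets/v1/data/**, overwriting changed files
#   storage.push("data", force=True)
#
#   # download datasets/v1/data/** into ./downloads/data/**
#   storage.pull("data", local_base_dir="downloads")
#
#   # list remote objects, then selectively delete temporary files
#   storage.list_objects("data")
#   storage.delete("data", include_regex=r"data/.*\.tmp")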