-
Notifications
You must be signed in to change notification settings - Fork 3
Feature/gsk 2448 slicing dataloader #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| SingleLandmarkData = Tuple[np.ndarray, np.ndarray, Optional[Dict[Any, Any]]] | ||
| BatchedLandmarkData = Tuple[Tuple[np.ndarray], np.ndarray, Tuple[Optional[Dict[Any, Any]]]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes!
| batched_elements[1] = np.array(batched_elements[1]) | ||
|
|
||
| def _collate_fn(self, elements: List[SingleLandmarkData]) -> BatchedLandmarkData: | ||
| batched_elements = list(zip(*elements, strict=True)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice touch with the strict
| # for key in meta_keys | ||
| # } | ||
|
|
||
| return batched_elements |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shame that I missed that one in the previous PR. We're definitely missing a lot of helpful tests.
| self._cached_functions = [ | ||
| lru_cache(maxsize=cache_size)(func) if should_cache else func | ||
| for should_cache, func in [ | ||
| (cache_img, self._wrapped_dataloader.get_image), | ||
| (cache_marks, self._wrapped_dataloader.get_marks), | ||
| (cache_meta, self._wrapped_dataloader.get_meta), | ||
| ] | ||
| ] | ||
|
|
||
| def get_image(self, idx: int) -> np.ndarray: | ||
| return self._cached_functions[0](idx) | ||
|
|
||
| def get_marks(self, idx: int) -> Optional[np.ndarray]: | ||
| return self._cached_functions[1](idx) | ||
|
|
||
| def get_meta(self, idx: int) -> Optional[Dict]: | ||
| return self._cached_functions[2](idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
way neater!!
| def __init__(self, dataloader: DataIteratorBase, predicate: Callable[[SingleLandmarkData], bool]): | ||
| super().__init__(dataloader) | ||
| self._predicate_name = predicate.__name__ if hasattr(predicate, "__name__") else str(predicate) | ||
| self._reindex = [ | ||
| idx | ||
| for idx in self._wrapped_dataloader.idx_sampler | ||
| if predicate(self._wrapped_dataloader.get_single_element(idx)) | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes!
No description provided.