Skip to content

Conversation

@Hartorn
Copy link
Member

@Hartorn Hartorn commented Jan 4, 2024

No description provided.

@Hartorn Hartorn requested a review from rabah-khalek January 4, 2024 10:50
@Hartorn Hartorn self-assigned this Jan 4, 2024
Base automatically changed from fix-batching to feature/gsk-2424-improve-the-hierarchy-structure-see-ffhq January 4, 2024 10:55
Base automatically changed from feature/gsk-2424-improve-the-hierarchy-structure-see-ffhq to main January 4, 2024 13:34
Comment on lines +8 to +9
SingleLandmarkData = Tuple[np.ndarray, np.ndarray, Optional[Dict[Any, Any]]]
BatchedLandmarkData = Tuple[Tuple[np.ndarray], np.ndarray, Tuple[Optional[Dict[Any, Any]]]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

batched_elements[1] = np.array(batched_elements[1])

def _collate_fn(self, elements: List[SingleLandmarkData]) -> BatchedLandmarkData:
batched_elements = list(zip(*elements, strict=True))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice touch with the strict

# for key in meta_keys
# }

return batched_elements
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shame that I missed that one in the previous PR. We're definitely missing a lot of helpful tests.

Comment on lines +44 to +60
self._cached_functions = [
lru_cache(maxsize=cache_size)(func) if should_cache else func
for should_cache, func in [
(cache_img, self._wrapped_dataloader.get_image),
(cache_marks, self._wrapped_dataloader.get_marks),
(cache_meta, self._wrapped_dataloader.get_meta),
]
]

def get_image(self, idx: int) -> np.ndarray:
return self._cached_functions[0](idx)

def get_marks(self, idx: int) -> Optional[np.ndarray]:
return self._cached_functions[1](idx)

def get_meta(self, idx: int) -> Optional[Dict]:
return self._cached_functions[2](idx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

way neater!!

Comment on lines +76 to +83
def __init__(self, dataloader: DataIteratorBase, predicate: Callable[[SingleLandmarkData], bool]):
super().__init__(dataloader)
self._predicate_name = predicate.__name__ if hasattr(predicate, "__name__") else str(predicate)
self._reindex = [
idx
for idx in self._wrapped_dataloader.idx_sampler
if predicate(self._wrapped_dataloader.get_single_element(idx))
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

@rabah-khalek rabah-khalek marked this pull request as ready for review January 4, 2024 15:07
@rabah-khalek rabah-khalek merged commit e1e3542 into main Jan 4, 2024
@rabah-khalek rabah-khalek deleted the feature/gsk-2448-slicing-dataloader branch January 4, 2024 15:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

3 participants