
Commit 384672e

Adding Filtering dataloader
1 parent c78a235 commit 384672e

File tree

2 files changed: +30 −18 lines changed


loreal_poc/dataloaders/base.py

Lines changed: 9 additions & 16 deletions

@@ -5,6 +5,9 @@
 
 import numpy as np
 
+SingleLandmarkData = Tuple[np.ndarray, np.ndarray, Optional[Dict[Any, Any]]]
+BatchedLandmarkData = Tuple[Tuple[np.ndarray], np.ndarray, Tuple[Optional[Dict[Any, Any]]]]
+
 
 class DataIteratorBase(ABC):
     batch_size: int
@@ -61,12 +64,10 @@ def get_meta_with_default(self, idx: int) -> np.ndarray:
         marks = marks if marks is not None else self.meta_none()
         return marks
 
-    def get_single_element(self, idx) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]:
+    def get_single_element(self, idx) -> SingleLandmarkData:
         return self.get_image(idx), self.get_marks_with_default(idx), self.get_meta_with_default(idx)
 
-    def __getitem__(
-        self, idx: int
-    ) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]:  # (image, marks, meta)
+    def __getitem__(self, idx: int) -> BatchedLandmarkData:
         return self._collate_fn(
             [self.get_single_element(i) for i in self.idx_sampler[idx * self.batch_size : (idx + 1) * self.batch_size]]
         )
@@ -84,30 +85,22 @@ def all_marks(self) -> np.ndarray:  # (marks)
     def all_meta(self) -> List:  # (meta)
         return [self.get_meta_with_default(idx) for idx in self.idx_sampler]
 
-    def __next__(self) -> Tuple[np.ndarray, np.ndarray]:
+    def __next__(self) -> BatchedLandmarkData:
         if self.idx >= len(self):
             raise StopIteration
         elt = self[self.idx]
         self.idx += 1
         return elt
 
-    def _collate_fn(
-        self, elements: List[Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]]
-    ) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]:
-        batched_elements = list(zip(*elements))
-        batched_elements[1] = np.array(batched_elements[1])
-
+    def _collate_fn(self, elements: List[SingleLandmarkData]) -> BatchedLandmarkData:
+        batched_elements = list(zip(*elements, strict=True))
         # INFO: Restore if we want to concatenate all meta under one dict instead of keeping them as records (list of dicts)
         # meta_keys = next((list(elt.keys()) for elt in batched_elements[2] if elt is not None), [])
         # batched_elements[2] = {
        #     key: [meta[key] if (meta is not None and key in meta) else None for meta in batched_elements[2]]
        #     for key in meta_keys
        # }
-
-        # if len(batched_elements[0]) != self.batch_size:
-        #     raise StopIteration
-
-        return batched_elements
+        return batched_elements[0], np.array(batched_elements[1]), batched_elements[2]
 
 
 class DataLoaderBase(DataIteratorBase):
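
As a rough illustration of the shapes the new type aliases pin down (not part of the commit; `dataloader` is assumed to be any concrete DataIteratorBase implementation):

# Illustrative sketch only; `dataloader` is an assumed, already-constructed loader.
image, marks, meta = dataloader.get_single_element(0)  # SingleLandmarkData
# image: np.ndarray, marks: np.ndarray, meta: Optional[dict]

images, batched_marks, metas = dataloader[0]  # BatchedLandmarkData
# images: tuple of np.ndarray, one entry per element in the batch
# batched_marks: np.ndarray stacking each element's marks
# metas: tuple of Optional[dict], kept as records rather than merged into one dict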

loreal_poc/dataloaders/wrappers.py

Lines changed: 21 additions & 2 deletions

@@ -1,10 +1,10 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
 from ..marks.facial_parts import FacialPart
 from ..transformation_functions.crop import crop_image_from_mark, crop_mark
-from .base import DataIteratorBase, DataLoaderWrapper
+from .base import DataIteratorBase, DataLoaderWrapper, SingleLandmarkData
 
 
 class CroppedDataLoader(DataLoaderWrapper):
@@ -51,3 +51,22 @@ def __getitem__(self, idx: int) -> Tuple[np.ndarray, Optional[np.ndarray], Optio
         if len(self._cache_idxs) > self._max_size:
             self._cache.pop(self._cache_idxs.pop(-1))
         return self._cache[idx]
+
+
+class FilteringDataLoader(DataLoaderWrapper):
+    @property
+    def name(self):
+        return f"({self._wrapped_dataloader.name}) filtered using {self._predicate_name}"
+
+    @property
+    def idx_sampler(self) -> np.ndarray:
+        return self._reindex
+
+    def __init__(self, dataloader: DataIteratorBase, predicate: Callable[[SingleLandmarkData], bool]):
+        super().__init__(dataloader)
+        self._predicate_name = predicate.__name__ if hasattr(predicate, "__name__") else str(predicate)
+        self._reindex = [
+            idx
+            for idx in self._wrapped_dataloader.idx_sampler
+            if predicate(self._wrapped_dataloader.get_single_element(idx))
+        ]
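
A minimal usage sketch of the new wrapper (not part of the commit; the `has_visible_marks` predicate and the wrapped `dataloader` are assumptions for illustration). FilteringDataLoader keeps only the indices whose (image, marks, meta) element satisfies the predicate and reports the predicate's name in its own name:

# Hypothetical example; `dataloader` is assumed to be an existing DataIteratorBase instance.
import numpy as np

from loreal_poc.dataloaders.wrappers import FilteringDataLoader


def has_visible_marks(element) -> bool:
    # element is a SingleLandmarkData tuple: (image, marks, meta)
    _, marks, _ = element
    return marks is not None and not np.isnan(marks).all()


filtered = FilteringDataLoader(dataloader, has_visible_marks)
print(filtered.name)  # "(<wrapped loader name>) filtered using has_visible_marks"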
