Skip to content

Commit 3dda50c

Browse files
authored
Merge pull request #36 from Giskard-AI/meta-wrapper
Refactoring with meta wrapper
2 parents 1c7f342 + 36947bc commit 3dda50c

File tree

20 files changed

+808
-786
lines changed

20 files changed

+808
-786
lines changed

‎examples/landmark_detection/criterias/criteria3_face_orientations.ipynb‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"cached_dl = CachedDataLoader(HeadPoseDataLoader(dl), cache_size=None, cache_img=False, cache_labels=False)\n",
4949
"\n",
5050
"for idx, (img, marks, meta) in enumerate(cached_dl):\n",
51-
" print(idx, meta[0][\"headPose\"])"
51+
" print(idx, meta[0])"
5252
]
5353
},
5454
{
@@ -57,7 +57,7 @@
5757
"metadata": {},
5858
"outputs": [],
5959
"source": [
60-
"head_pose_dl = FilteredDataLoader(cached_dl, lambda elt: elt[2][\"headPose\"][\"roll\"] > 0)"
60+
"head_pose_dl = FilteredDataLoader(cached_dl, lambda elt: elt[2].get_includes(\"roll\") > 0)"
6161
]
6262
},
6363
{
@@ -78,7 +78,7 @@
7878
],
7979
"source": [
8080
"for idx, (img, marks, meta) in enumerate(head_pose_dl):\n",
81-
" print(head_pose_dl._reindex[idx], meta[0][\"headPose\"])"
81+
" print(head_pose_dl._reindex[idx], meta[0])"
8282
]
8383
},
8484
{

‎examples/landmark_detection/criterias/criteria4_ethnicity.ipynb‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
"cached_dl = CachedDataLoader(ethnicity_dl, cache_size=None, cache_img=False, cache_labels=False)\n",
4343
"dl.name\n",
4444
"\n",
45-
"asians = FilteredDataLoader(cached_dl, lambda elt: elt[2][\"ethnicity\"] == \"asian\")\n",
45+
"asians = FilteredDataLoader(cached_dl, lambda elt: elt[2].get_includes(\"ethnicity\") == \"asian\")\n",
4646
"asians._reindex"
4747
]
4848
}

‎examples/landmark_detection/demo/ethnicity_criteria.ipynb‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@
5454
"cached_dl = CachedDataLoader(ethnicity_dl, cache_size=None, cache_img=False, cache_labels=False)\n",
5555
"dl.name\n",
5656
"\n",
57-
"asians = FilteredDataLoader(cached_dl, lambda elt: elt[2][\"ethnicity\"] == \"asian\")\n",
58-
"whites = FilteredDataLoader(cached_dl, lambda elt: elt[2][\"ethnicity\"] == \"white\")"
57+
"asians = FilteredDataLoader(cached_dl, lambda elt: elt[2].get_includes(\"ethnicity\") == \"asian\")\n",
58+
"whites = FilteredDataLoader(cached_dl, lambda elt: elt[2].get_includes(\"ethnicity\") == \"white\")"
5959
]
6060
},
6161
{

‎examples/landmark_detection/demo/report_example.ipynb‎

Lines changed: 417 additions & 445 deletions
Large diffs are not rendered by default.

‎examples/landmark_detection/ffhq/ffhq-filtering-caching.ipynb‎

Lines changed: 38 additions & 215 deletions
Large diffs are not rendered by default.

‎examples/landmark_detection/ffhq/ffhq-filtering.ipynb‎

Lines changed: 8 additions & 32 deletions
Large diffs are not rendered by default.

‎examples/landmark_detection/ffhq/ffhq.ipynb‎

Lines changed: 23 additions & 38 deletions
Large diffs are not rendered by default.

‎examples/landmark_detection/report/master.ipynb‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,11 @@
238238
"\n",
239239
"\n",
240240
"def positive_roll(elt):\n",
241-
" return elt[2][\"headPose\"][\"roll\"] > 0\n",
241+
" return elt[2].get_includes(\"roll\") > 0\n",
242242
"\n",
243243
"\n",
244244
"def negative_roll(elt):\n",
245-
" return elt[2][\"headPose\"][\"roll\"] < 0\n",
245+
" return elt[2].get_includes(\"roll\") < 0\n",
246246
"\n",
247247
"\n",
248248
"head_poses = [positive_roll, negative_roll]\n",
@@ -288,7 +288,7 @@
288288
"\n",
289289
"\n",
290290
"def white_ethnicity(elt):\n",
291-
" return elt[2][\"ethnicity\"] == \"white\"\n",
291+
" return elt[2].get_includes(\"ethnicity\") == \"white\"\n",
292292
"\n",
293293
"\n",
294294
"ethnicities = [white_ethnicity]\n",

‎examples/landmark_detection/report/report.ipynb‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@
6767
"\n",
6868
"# head pose filtering\n",
6969
"def positive_roll(elt):\n",
70-
" return elt[2][\"headPose\"][\"roll\"] > 0\n",
70+
" return elt[2].get_includes(\"roll\") > 0\n",
7171
"\n",
7272
"\n",
7373
"def negative_roll(elt):\n",
74-
" return elt[2][\"headPose\"][\"roll\"] < 0\n",
74+
" return elt[2].get_includes(\"roll\") < 0\n",
7575
"\n",
7676
"\n",
7777
"cached_dl = CachedDataLoader(HeadPoseDataLoader(dl_ref), cache_size=None, cache_img=False, cache_labels=False)\n",
@@ -81,11 +81,11 @@
8181
"\n",
8282
"# ethnicity filtering\n",
8383
"def white_ethnicity(elt):\n",
84-
" return elt[2][\"ethnicity\"] == \"white\"\n",
84+
" return elt[2].get_includes(\"ethnicity\") == \"white\"\n",
8585
"\n",
8686
"\n",
8787
"def latino_ethnicity(elt):\n",
88-
" return elt[2][\"ethnicity\"] == \"latino hispanic\"\n",
88+
" return elt[2].get_includes(\"ethnicity\") == \"latino hispanic\"\n",
8989
"\n",
9090
"\n",
9191
"cached_dl = CachedDataLoader(\n",

‎giskard_vision/core/dataloaders/base.py‎

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import math
22
from abc import ABC, abstractmethod
3-
from typing import Any, Dict, List, Optional, Tuple
3+
from typing import List, Optional
44

55
import numpy as np
66

7-
SingleLabel = Tuple[np.ndarray, Any, Optional[Dict[Any, Any]]]
8-
BatchedLabels = Tuple[Tuple[np.ndarray], Any, Tuple[Optional[Dict[Any, Any]]]]
7+
from ..types import TypesBase
98

109

1110
class DataIteratorBase(ABC):
@@ -104,7 +103,7 @@ def labels_none(self) -> Optional[np.ndarray]:
104103
"""
105104
return None
106105

107-
def meta_none(self) -> Optional[Dict]:
106+
def meta_none(self) -> Optional[TypesBase.meta]:
108107
"""
109108
Returns default for meta data if it is None.
110109
@@ -125,15 +124,15 @@ def get_labels(self, idx: int) -> Optional[np.ndarray]:
125124
"""
126125
return None
127126

128-
def get_meta(self, idx: int) -> Optional[Dict]:
127+
def get_meta(self, idx: int) -> Optional[TypesBase.meta]:
129128
"""
130129
Gets meta information (for a single image) for a specific index.
131130
132131
Args:
133132
idx (int): Index of the image.
134133
135134
Returns:
136-
Optional[Dict]: Meta information for the given index.
135+
Optional[TypesBase.meta]: Meta information for the given index.
137136
"""
138137
return None
139138

@@ -163,27 +162,27 @@ def get_meta_with_default(self, idx: int) -> np.ndarray:
163162
meta = self.get_meta(idx)
164163
return meta if meta is not None else self.meta_none()
165164

166-
def get_single_element(self, idx) -> SingleLabel:
165+
def get_single_element(self, idx) -> TypesBase.single_data:
167166
"""
168167
Gets a single element as a tuple of (image, labels, meta) for a specific index.
169168
170169
Args:
171170
idx (int): Index of the image.
172171
173172
Returns:
174-
SingleLabel: Tuple containing image, labels, and meta information.
173+
TypesBase.single_data: Tuple containing image, labels, and meta information.
175174
"""
176175
return self.get_image(idx), self.get_labels_with_default(idx), self.get_meta_with_default(idx)
177176

178-
def __getitem__(self, idx: int) -> BatchedLabels:
177+
def __getitem__(self, idx: int) -> TypesBase.batched_data:
179178
"""
180179
Gets a batch of elements for a specific index.
181180
182181
Args:
183182
idx (int): Index of the batch.
184183
185184
Returns:
186-
BatchedLabels: Batched data containing images, labels, and meta information.
185+
TypesBase.batched_data: Batched data containing images, labels, and meta information.
187186
"""
188187
return self._collate_fn(
189188
[self.get_single_element(i) for i in self.idx_sampler[idx * self.batch_size : (idx + 1) * self.batch_size]]
@@ -220,28 +219,28 @@ def all_meta(self) -> List:
220219
"""
221220
return [self.get_meta_with_default(idx) for idx in self.idx_sampler]
222221

223-
def __next__(self) -> BatchedLabels:
222+
def __next__(self) -> TypesBase.batched_data:
224223
"""
225224
Gets the next batch of elements.
226225
227226
Returns:
228-
BatchedLabels: Batched data containing images, labels, and meta information.
227+
TypesBase.batched_data: Batched data containing images, labels, and meta information.
229228
"""
230229
if self.idx >= len(self):
231230
raise StopIteration
232231
elt = self[self.idx]
233232
self.idx += 1
234233
return elt
235234

236-
def _collate_fn(self, elements: List[SingleLabel]) -> BatchedLabels:
235+
def _collate_fn(self, elements: List[TypesBase.single_data]) -> TypesBase.batched_data:
237236
"""
238237
Collates a list of single elements into a batched element.
239238
240239
Args:
241-
elements (List[SingleLabel]): List of single elements.
240+
elements (List[TypesBase.single_data]): List of single elements.
242241
243242
Returns:
244-
BatchedLabels: Batched data containing images, labels, and meta information.
243+
TypesBase.batched_data: Batched data containing images, labels, and meta information.
245244
"""
246245
batched_elements = list(zip(*elements, strict=True))
247246
return batched_elements[0], np.array(batched_elements[1]), batched_elements[2]
@@ -311,15 +310,15 @@ def get_labels(self, idx: int) -> Optional[np.ndarray]:
311310
"""
312311
return self._wrapped_dataloader.get_labels(idx)
313312

314-
def get_meta(self, idx: int) -> Optional[Dict]:
313+
def get_meta(self, idx: int) -> Optional[TypesBase.meta]:
315314
"""
316315
Gets meta information from the wrapped data loader.
317316
318317
Args:
319318
idx (int): Index of the data.
320319
321320
Returns:
322-
Optional[Dict]: Meta information from the wrapped data loader.
321+
Optional[TypesBase.meta]: Meta information from the wrapped data loader.
323322
"""
324323
return self._wrapped_dataloader.get_meta(idx)
325324

0 commit comments

Comments
 (0)