Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions examples/SixDRepNet_benchmark.ipynb

Large diffs are not rendered by default.

112 changes: 112 additions & 0 deletions examples/criteria3_face_orientations.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from loreal_poc.dataloaders.loaders import DataLoaderFFHQ\n",
"from loreal_poc.dataloaders.wrappers import DataLoaderWrapper, CachedDataLoader, FilteredDataLoader, HeadPoseDataLoader\n",
"from sixdrepnet import SixDRepNet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"dl = DataLoaderFFHQ(\"ffhq\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 {'pitch': -3.2351706, 'yaw': 17.123957, 'roll': 2.4609118}\n",
"1 {'pitch': -9.634953, 'yaw': -5.263, 'roll': -1.0166222}\n",
"2 {'pitch': -12.489754, 'yaw': 7.760907, 'roll': 3.017409}\n",
"3 {'pitch': -8.279707, 'yaw': 0.58123016, 'roll': -0.17205757}\n",
"4 {'pitch': -8.762704, 'yaw': -4.9386187, 'roll': -0.59270185}\n",
"5 {'pitch': -2.703518, 'yaw': 3.4559636, 'roll': -3.7665286}\n",
"6 {'pitch': -13.6264105, 'yaw': -28.628199, 'roll': -2.5795803}\n",
"7 {'pitch': -17.597815, 'yaw': 16.94254, 'roll': 4.6466646}\n",
"8 {'pitch': -8.4020605, 'yaw': 6.840177, 'roll': -0.92642134}\n",
"9 {'pitch': 13.562258, 'yaw': 29.946465, 'roll': -6.7045293}\n",
"10 {'pitch': -14.822533, 'yaw': 3.8378444, 'roll': 0.5621732}\n"
]
}
],
"source": [
"cached_dl = CachedDataLoader(HeadPoseDataLoader(dl), cache_size=None, cache_img=False, cache_marks=False)\n",
"\n",
"for idx, (img, marks, meta) in enumerate(cached_dl):\n",
" print(idx, meta[0][\"headPose\"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"head_pose_dl = FilteredDataLoader(cached_dl, lambda elt: elt[2][\"headPose\"][\"roll\"] > 0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 {'pitch': -3.2351706, 'yaw': 17.123957, 'roll': 2.4609118}\n",
"2 {'pitch': -12.489754, 'yaw': 7.760907, 'roll': 3.017409}\n",
"7 {'pitch': -17.597815, 'yaw': 16.94254, 'roll': 4.6466646}\n",
"10 {'pitch': -14.822533, 'yaw': 3.8378444, 'roll': 0.5621732}\n"
]
}
],
"source": [
"for idx, (img, marks, meta) in enumerate(head_pose_dl):\n",
" print(head_pose_dl._reindex[idx], meta[0][\"headPose\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
6 changes: 3 additions & 3 deletions examples/ffhq-filtering-caching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"outputs": [],
"source": [
"from loreal_poc.dataloaders.loaders import DataLoaderFFHQ\n",
"from loreal_poc.dataloaders.wrappers import FilteringDataLoader, CachedDataLoader\n",
"from loreal_poc.dataloaders.wrappers import FilteredDataLoader, CachedDataLoader\n",
"from loreal_poc.dataloaders.base import DataLoaderWrapper\n",
"from loreal_poc.visualisation.draw import draw_marks\n",
"\n",
Expand Down Expand Up @@ -215,7 +215,7 @@
}
],
"source": [
"odds = FilteringDataLoader(dl, lambda elt: elt[2][\"type\"] == \"odd\")\n",
"odds = FilteredDataLoader(dl, lambda elt: elt[2][\"type\"] == \"odd\")\n",
"len(odds)"
]
},
Expand Down Expand Up @@ -256,7 +256,7 @@
}
],
"source": [
"evens = FilteringDataLoader(dl, lambda elt: elt[2][\"type\"] == \"even\")\n",
"evens = FilteredDataLoader(dl, lambda elt: elt[2][\"type\"] == \"even\")\n",
"len(evens)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions examples/ffhq-filtering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"outputs": [],
"source": [
"from loreal_poc.dataloaders.loaders import DataLoaderFFHQ\n",
"from loreal_poc.dataloaders.wrappers import FilteringDataLoader, CachedDataLoader\n",
"from loreal_poc.dataloaders.wrappers import FilteredDataLoader, CachedDataLoader\n",
"from loreal_poc.dataloaders.base import DataLoaderWrapper\n",
"from loreal_poc.visualisation.draw import draw_marks\n",
"\n",
Expand Down Expand Up @@ -35,7 +35,7 @@
"source": [
"img, marks, meta = dl[0]\n",
"\n",
"fdl = FilteringDataLoader(dl, lambda elt: elt[2][\"faceAttributes\"][\"headPose\"][\"yaw\"] < 0)\n",
"fdl = FilteredDataLoader(dl, lambda elt: elt[2][\"faceAttributes\"][\"headPose\"][\"yaw\"] < 0)\n",
"len(fdl)"
]
},
Expand Down
32 changes: 31 additions & 1 deletion loreal_poc/dataloaders/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
resize_image,
resize_marks,
)
from ..utils.errors import GiskardImportError
from .base import DataIteratorBase, DataLoaderWrapper, SingleLandmarkData


Expand Down Expand Up @@ -138,7 +139,7 @@ def get_image(self, idx: int) -> np.ndarray:
return cv2.cvtColor(image, self._mode)


class FilteringDataLoader(DataLoaderWrapper):
class FilteredDataLoader(DataLoaderWrapper):
@property
def name(self):
return f"({self._wrapped_dataloader.name}) filtered using '{self._predicate_name}'"
Expand All @@ -155,3 +156,32 @@ def __init__(self, dataloader: DataIteratorBase, predicate: Callable[[SingleLand
for idx in self._wrapped_dataloader.idx_sampler
if predicate(self._wrapped_dataloader.get_single_element(idx))
]


class HeadPoseDataLoader(DataLoaderWrapper):
    """Dataloader wrapper whose metadata is a head pose estimated from the image.

    The pose is predicted with the SixDRepNet model. The ``sixdrepnet`` package
    is imported inside ``__init__`` so it is only needed when this wrapper is
    actually instantiated.
    """

    def __init__(self, dataloader: DataIteratorBase, gpu_id: int = -1) -> None:
        """A dataloader that estimates the head pose in images using the SixDRepNet model

        Args:
            dataloader (DataIteratorBase): the wrapped dataloader.
            gpu_id (int, optional): Enable the usage of GPUs. Defaults to -1 (CPU).

        Raises:
            GiskardImportError: Error to signal a missing package
        """
        # Check for the optional dependency first so a missing package fails
        # fast, before the wrapper is set up.
        try:
            from sixdrepnet import SixDRepNet
        except ImportError as e:
            raise GiskardImportError("sixdrepnet") from e

        super().__init__(dataloader)

        self.pose_detection_model = SixDRepNet(gpu_id=gpu_id)

    @property
    def name(self):
        """Name of the wrapped dataloader, suffixed with a head-pose marker."""
        return f"({self._wrapped_dataloader.name}) with head pose estimation"

    def get_meta(self, idx):
        """Return the estimated head pose of the image at ``idx``.

        Args:
            idx: index of the sample, forwarded to ``get_image``.

        Returns:
            A dict of the form
            ``{"headPose": {"pitch": ..., "yaw": ..., "roll": ...}}`` built from
            the first element of each value predicted by the model.
        """
        pitch, yaw, roll = self.pose_detection_model.predict(self.get_image(idx))
        # NOTE(review): yaw is negated here, presumably to match the sign
        # convention of the dataset's own headPose annotations; confirm.
        return {"headPose": {"pitch": pitch[0], "yaw": -yaw[0], "roll": roll[0]}}
Empty file added loreal_poc/utils/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions loreal_poc/utils/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class GiskardImportError(ImportError):
    """ImportError raised when a required third-party package is not installed.

    The message tells the user which package is missing and how to install it.
    """

    def __init__(self, missing_package: str) -> None:
        """Build the error for ``missing_package``.

        Args:
            missing_package: name of the package that could not be imported;
                it is used in the message and stored in the standard
                ``ImportError.name`` attribute.
        """
        message = (
            f"The '{missing_package}' Python package is not installed; "
            f"please execute 'pip install {missing_package}' to obtain it."
        )
        # Passing the message to ImportError sets ``msg`` *and* ``args``, so
        # ``str()``, ``repr()`` and logging all carry the text; ``name`` is the
        # attribute ImportError reserves for the module that failed to import.
        super().__init__(message, name=missing_package)

    def __reduce__(self):
        # Rebuild from the package name: the default reduction would pass the
        # already formatted message back in as ``missing_package``.
        return (type(self), (self.name,))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we call `super().__init__()`?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is called right after, I try to check first for the missing dependency.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure there are only 3 lines in this file, did I miss something ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I thought you were talking about the HeadPoseDataloader.

In fact I copied the above from https://github.com/Giskard-AI/giskard/blob/858c3c101382fb6f1933ed388984fc69cacd7195/giskard/core/errors.py#L22-L24

If the `__init__` just assigns `msg`, then it doesn't really matter. Let me check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here's the ImportError class's init:

class ImportError(Exception):
    def __init__(self, *args: object, name: str | None = ..., path: str | None = ...) -> None: 
      ...
67 changes: 55 additions & 12 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dev = [
"pre-commit>=2.19.0",
"ruff",
"isort",
"sixdrepnet>=0.1.6",
]

[tool.ruff]
Expand Down
12 changes: 10 additions & 2 deletions tests/dataloaders/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from loreal_poc.dataloaders.wrappers import (
CachedDataLoader,
CroppedDataLoader,
FilteringDataLoader,
FilteredDataLoader,
HeadPoseDataLoader,
ResizedDataLoader,
)
from loreal_poc.marks.facial_parts import FacialParts
Expand Down Expand Up @@ -59,7 +60,7 @@ def is_odd(elt: SingleLandmarkData) -> bool:
def test_filtering_dataloader():
dl = WithMetaDataLoader(DataloaderForTest("example", length=10))

filtered = FilteringDataLoader(dl, predicate=is_odd)
filtered = FilteredDataLoader(dl, predicate=is_odd)

assert len(filtered) == 5
assert "filtered using 'is_odd'" in filtered.name
Expand All @@ -83,3 +84,10 @@ def test_resized_dataloader():
for resized_img, resized_marks, _ in resized:
assert resized_img[0].shape[0] == 300
assert resized_img[0].shape[1] == 500


def test_headpose_dataloader(dataset_ffhq):
    """Filtering on the estimated pose keeps only samples with a strictly positive roll."""
    pose_loader = HeadPoseDataLoader(dataset_ffhq)
    positive_roll = FilteredDataLoader(pose_loader, lambda elt: elt[2]["headPose"]["roll"] > 0)

    # Indices (in the wrapped dataloader) of the samples expected to pass the filter.
    expected_indices = [0, 2, 7, 10]

    assert len(positive_roll) == 4
    assert np.array_equal(positive_roll._reindex, expected_indices)