Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions examples/SixDRepNet_benchmark.ipynb

Large diffs are not rendered by default.

112 changes: 112 additions & 0 deletions examples/criteria3_face_orientations.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from loreal_poc.dataloaders.loaders import DataLoaderFFHQ\n",
"from loreal_poc.dataloaders.wrappers import DataLoaderWrapper, CachedDataLoader, FilteredDataLoader, HeadPoseDataLoader\n",
"from sixdrepnet import SixDRepNet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"dl = DataLoaderFFHQ(\"ffhq\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 {'pitch': -3.2351706, 'yaw': 17.123957, 'roll': 2.4609118}\n",
"1 {'pitch': -9.634953, 'yaw': -5.263, 'roll': -1.0166222}\n",
"2 {'pitch': -12.489754, 'yaw': 7.760907, 'roll': 3.017409}\n",
"3 {'pitch': -8.279707, 'yaw': 0.58123016, 'roll': -0.17205757}\n",
"4 {'pitch': -8.762704, 'yaw': -4.9386187, 'roll': -0.59270185}\n",
"5 {'pitch': -2.703518, 'yaw': 3.4559636, 'roll': -3.7665286}\n",
"6 {'pitch': -13.6264105, 'yaw': -28.628199, 'roll': -2.5795803}\n",
"7 {'pitch': -17.597815, 'yaw': 16.94254, 'roll': 4.6466646}\n",
"8 {'pitch': -8.4020605, 'yaw': 6.840177, 'roll': -0.92642134}\n",
"9 {'pitch': 13.562258, 'yaw': 29.946465, 'roll': -6.7045293}\n",
"10 {'pitch': -14.822533, 'yaw': 3.8378444, 'roll': 0.5621732}\n"
]
}
],
"source": [
"cached_dl = CachedDataLoader(HeadPoseDataLoader(dl), cache_size=None, cache_img=False, cache_marks=False)\n",
"\n",
"for idx, (img, marks, meta) in enumerate(cached_dl):\n",
" print(idx, meta[0][\"headPose\"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"head_pose_dl = FilteredDataLoader(cached_dl, lambda elt: elt[2][\"headPose\"][\"roll\"] > 0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 {'pitch': -3.2351706, 'yaw': 17.123957, 'roll': 2.4609118}\n",
"2 {'pitch': -12.489754, 'yaw': 7.760907, 'roll': 3.017409}\n",
"7 {'pitch': -17.597815, 'yaw': 16.94254, 'roll': 4.6466646}\n",
"10 {'pitch': -14.822533, 'yaw': 3.8378444, 'roll': 0.5621732}\n"
]
}
],
"source": [
"for idx, (img, marks, meta) in enumerate(head_pose_dl):\n",
" print(head_pose_dl._reindex[idx], meta[0][\"headPose\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
23 changes: 22 additions & 1 deletion loreal_poc/dataloaders/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
resize_image,
resize_marks,
)
from ..utils.errors import GiskardImportError
from .base import DataIteratorBase, DataLoaderWrapper, SingleLandmarkData


Expand Down Expand Up @@ -138,7 +139,7 @@ def get_image(self, idx: int) -> np.ndarray:
return cv2.cvtColor(image, self._mode)


class FilteringDataLoader(DataLoaderWrapper):
class FilteredDataLoader(DataLoaderWrapper):
@property
def name(self):
return f"({self._wrapped_dataloader.name}) filtered using '{self._predicate_name}'"
Expand All @@ -155,3 +156,23 @@ def __init__(self, dataloader: DataIteratorBase, predicate: Callable[[SingleLand
for idx in self._wrapped_dataloader.idx_sampler
if predicate(self._wrapped_dataloader.get_single_element(idx))
]


class HeadPoseDataLoader(DataLoaderWrapper):
def __init__(self, dataloader: DataIteratorBase) -> None:
try:
from sixdrepnet import SixDRepNet
except ImportError as e:
raise GiskardImportError("sixdrepnet") from e

super().__init__(dataloader)

self.pose_detection_model = SixDRepNet(gpu_id=-1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-1 means no gpu ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah -1 is cpu, I'll make this one a parameter, just in case people want to run it on gpu. Good catch.


@property
def name(self):
return f"({self._wrapped_dataloader.name}) with head pose estimation'"

def get_meta(self, idx):
pitch, yaw, roll = self.pose_detection_model.predict(self.get_image(idx))
return {"headPose": {"pitch": pitch[0], "yaw": -yaw[0], "roll": roll[0]}}
Empty file added loreal_poc/utils/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions loreal_poc/utils/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class GiskardImportError(ImportError):
def __init__(self, missing_package: str) -> None:
self.msg = f"The '{missing_package}' Python package is not installed; please execute 'pip install {missing_package}' to obtain it."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we call super().init ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is called right after, I try to check first for the missing dependency.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure there are only 3 lines in this file, did I miss something ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I thought you were talking about the HeadPoseDataloader.

In fact I copied the above from https://github.com/Giskard-AI/giskard/blob/858c3c101382fb6f1933ed388984fc69cacd7195/giskard/core/errors.py#L22-L24

if the init just assigns msg than it doesn't really matter. let me check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here's the ImportError class's init:

class ImportError(Exception):
    def __init__(self, *args: object, name: str | None = ..., path: str | None = ...) -> None: 
      ...
67 changes: 55 additions & 12 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dev = [
"pre-commit>=2.19.0",
"ruff",
"isort",
"sixdrepnet>=0.1.6",
]

[tool.ruff]
Expand Down
12 changes: 10 additions & 2 deletions tests/dataloaders/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from loreal_poc.dataloaders.wrappers import (
CachedDataLoader,
CroppedDataLoader,
FilteringDataLoader,
FilteredDataLoader,
HeadPoseDataLoader,
ResizedDataLoader,
)
from loreal_poc.marks.facial_parts import FacialParts
Expand Down Expand Up @@ -59,7 +60,7 @@ def is_odd(elt: SingleLandmarkData) -> bool:
def test_filtering_dataloader():
dl = WithMetaDataLoader(DataloaderForTest("example", length=10))

filtered = FilteringDataLoader(dl, predicate=is_odd)
filtered = FilteredDataLoader(dl, predicate=is_odd)

assert len(filtered) == 5
assert "filtered using 'is_odd'" in filtered.name
Expand All @@ -83,3 +84,10 @@ def test_resized_dataloader():
for resized_img, resized_marks, _ in resized:
assert resized_img[0].shape[0] == 300
assert resized_img[0].shape[1] == 500


def test_headpose_dataloader(dataset_ffhq):
head_pose_dl = FilteredDataLoader(HeadPoseDataLoader(dataset_ffhq), lambda elt: elt[2]["headPose"]["roll"] > 0)

assert len(head_pose_dl) == 4
assert np.array_equal(head_pose_dl._reindex, [0, 2, 7, 10])