Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
added gpu_id as parameter, refactored some notebooks
  • Loading branch information
rabah-khalek committed Jan 8, 2024
commit 03a62203d9eb0afb7dd4b1bfee5c9cd526a04f00
6 changes: 3 additions & 3 deletions examples/ffhq-filtering-caching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"outputs": [],
"source": [
"from loreal_poc.dataloaders.loaders import DataLoaderFFHQ\n",
"from loreal_poc.dataloaders.wrappers import FilteringDataLoader, CachedDataLoader\n",
"from loreal_poc.dataloaders.wrappers import FilteredDataLoader, CachedDataLoader\n",
"from loreal_poc.dataloaders.base import DataLoaderWrapper\n",
"from loreal_poc.visualisation.draw import draw_marks\n",
"\n",
Expand Down Expand Up @@ -215,7 +215,7 @@
}
],
"source": [
"odds = FilteringDataLoader(dl, lambda elt: elt[2][\"type\"] == \"odd\")\n",
"odds = FilteredDataLoader(dl, lambda elt: elt[2][\"type\"] == \"odd\")\n",
"len(odds)"
]
},
Expand Down Expand Up @@ -256,7 +256,7 @@
}
],
"source": [
"evens = FilteringDataLoader(dl, lambda elt: elt[2][\"type\"] == \"even\")\n",
"evens = FilteredDataLoader(dl, lambda elt: elt[2][\"type\"] == \"even\")\n",
"len(evens)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions examples/ffhq-filtering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"outputs": [],
"source": [
"from loreal_poc.dataloaders.loaders import DataLoaderFFHQ\n",
"from loreal_poc.dataloaders.wrappers import FilteringDataLoader, CachedDataLoader\n",
"from loreal_poc.dataloaders.wrappers import FilteredDataLoader, CachedDataLoader\n",
"from loreal_poc.dataloaders.base import DataLoaderWrapper\n",
"from loreal_poc.visualisation.draw import draw_marks\n",
"\n",
Expand Down Expand Up @@ -35,7 +35,7 @@
"source": [
"img, marks, meta = dl[0]\n",
"\n",
"fdl = FilteringDataLoader(dl, lambda elt: elt[2][\"faceAttributes\"][\"headPose\"][\"yaw\"] < 0)\n",
"fdl = FilteredDataLoader(dl, lambda elt: elt[2][\"faceAttributes\"][\"headPose\"][\"yaw\"] < 0)\n",
"len(fdl)"
]
},
Expand Down
13 changes: 11 additions & 2 deletions loreal_poc/dataloaders/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,24 @@ def __init__(self, dataloader: DataIteratorBase, predicate: Callable[[SingleLand


class HeadPoseDataLoader(DataLoaderWrapper):
def __init__(self, dataloader: DataIteratorBase) -> None:
def __init__(self, dataloader: DataIteratorBase, gpu_id: int = -1) -> None:
"""A dataloader that estimates the head pose in images using the SixDRepNet model

Args:
dataloader (DataIteratorBase): the wrapped dataloader.
gpu_id (int, optional): Enable the usage of GPUs. Defaults to -1 (CPU).

Raises:
GiskardImportError: Error to signal a missing package
"""
try:
from sixdrepnet import SixDRepNet
except ImportError as e:
raise GiskardImportError("sixdrepnet") from e

super().__init__(dataloader)

self.pose_detection_model = SixDRepNet(gpu_id=-1)
self.pose_detection_model = SixDRepNet(gpu_id=gpu_id)

@property
def name(self):
Expand Down
1 change: 1 addition & 0 deletions loreal_poc/utils/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
class GiskardImportError(ImportError):
def __init__(self, missing_package: str) -> None:
super().__init__()
self.msg = f"The '{missing_package}' Python package is not installed; please execute 'pip install {missing_package}' to obtain it."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we call super().init ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is called right after, I try to check first for the missing dependency.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure there are only 3 lines in this file, did I miss something ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I thought you were talking about the HeadPoseDataloader.

In fact I copied the above from https://github.com/Giskard-AI/giskard/blob/858c3c101382fb6f1933ed388984fc69cacd7195/giskard/core/errors.py#L22-L24

if the init just assigns msg than it doesn't really matter. let me check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here's the ImportError class's init:

class ImportError(Exception):
    def __init__(self, *args: object, name: str | None = ..., path: str | None = ...) -> None: 
      ...