Skip to content

Commit 976a170

Browse files
committed
Improve a bit FFHQ structure
1 parent 70993c0 commit 976a170

File tree

4 files changed

+55
-56
lines changed

4 files changed

+55
-56
lines changed

‎loreal_poc/dataloaders/base.py‎

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ def __len__(self) -> int:
3030
def get_image(self, idx: int) -> np.ndarray:
3131
...
3232

33-
@property
34-
def marks_none(self) -> Optional[np.ndarray]:
33+
@classmethod
34+
def marks_none(cls) -> Optional[np.ndarray]:
3535
return None
3636

37-
@property
38-
def meta_none(self) -> Optional[Dict]:
37+
@classmethod
38+
def meta_none(cls) -> Optional[Dict]:
3939
return None
4040

4141
def get_marks(self, idx: int) -> Optional[np.ndarray]:
@@ -49,9 +49,9 @@ def __getitem__(
4949
) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]: # (image, marks, meta)
5050
idx = self.idx_sampler[idx]
5151
marks = self.get_marks(idx)
52-
marks = marks if marks is not None else self.marks_none
52+
marks = marks if marks is not None else self.marks_none()
5353
meta = self.get_meta(idx)
54-
meta = meta if meta is not None else self.meta_none
54+
meta = meta if meta is not None else self.meta_none()
5555
return self.get_image(idx), marks, meta
5656

5757
@property
@@ -114,16 +114,21 @@ def __init__(
114114
collate_fn: Optional[Callable] = None,
115115
) -> None:
116116
super().__init__(name, batch_size=batch_size)
117+
# Get the images paths
117118
images_dir_path = self._get_absolute_local_path(images_dir_path)
118-
landmarks_dir_path = self._get_absolute_local_path(landmarks_dir_path)
119-
120119
self.image_paths = self._get_all_paths_based_on_suffix(images_dir_path, self.image_suffix)
121-
self.marks_paths = self._get_all_paths_based_on_suffix(landmarks_dir_path, self.marks_suffix)
122-
if len(self.marks_paths) != len(self.image_paths):
123-
raise ValueError(
124-
f"{self.__class__.__name__}: Only {len(self.marks_paths)} found "
125-
f"for {len(self.marks_paths)} of the images."
126-
)
120+
121+
self.marks_paths = None
122+
# If landmarks folder is not none, we should load them
123+
# Else, the get marks method should be overridden
124+
if landmarks_dir_path is not None:
125+
landmarks_dir_path = self._get_absolute_local_path(landmarks_dir_path)
126+
self.marks_paths = self._get_all_paths_based_on_suffix(landmarks_dir_path, self.marks_suffix)
127+
if len(self.marks_paths) != len(self.image_paths):
128+
raise ValueError(
129+
f"{self.__class__.__name__}: Only {len(self.marks_paths)} found "
130+
f"for {len(self.marks_paths)} of the images."
131+
)
127132

128133
self.shuffle = shuffle
129134

@@ -151,9 +156,7 @@ def _get_absolute_local_path(self, local_path: Union[str, Path]) -> Path:
151156

152157
@classmethod
153158
def _get_all_paths_based_on_suffix(cls, dir_path: Path, suffix: str) -> List[Path]:
154-
all_paths_with_suffix = list(
155-
sorted([p for p in dir_path.iterdir() if p.suffix == suffix], key=lambda p: str(p))
156-
)
159+
all_paths_with_suffix = list(sorted([p for p in dir_path.iterdir() if p.suffix == suffix], key=str))
157160
if len(all_paths_with_suffix) == 0:
158161
raise ValueError(
159162
f"{cls.__class__.__name__}: Landmarks with suffix {suffix}"
@@ -162,11 +165,11 @@ def _get_all_paths_based_on_suffix(cls, dir_path: Path, suffix: str) -> List[Pat
162165
return all_paths_with_suffix
163166

164167
def __len__(self) -> int:
165-
return math.floor(len(self.image_paths) / self.batch_size)
168+
return math.ceil(len(self.image_paths) / self.batch_size)
166169

167-
@property
168-
def marks_none(self):
169-
return np.full((self.n_landmarks, self.n_landmarks), np.nan)
170+
@classmethod
171+
def marks_none(cls) -> np.ndarray:
172+
return np.full((cls.n_landmarks, cls.n_landmarks), np.nan)
170173

171174
def get_image(self, idx: int) -> np.ndarray:
172175
return self._load_and_validate_image(self.image_paths[idx])

‎loreal_poc/dataloaders/loaders.py‎

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from pathlib import Path
3-
from typing import Callable, Dict, Optional, Union
3+
from typing import Callable, Dict, List, Optional, Union
44

55
import cv2
66
import numpy as np
@@ -72,34 +72,30 @@ def __init__(
7272
rng_seed: Optional[int] = None,
7373
collate_fn: Optional[Callable] = None,
7474
) -> None:
75-
# TODO!!: super __init__!!
75+
super().__init__(
76+
images_dir_path=dir_path,
77+
landmarks_dir_path=None,
78+
name=name,
79+
batch_size=batch_size,
80+
collate_fn=collate_fn,
81+
rng_seed=rng_seed,
82+
shuffle=shuffle,
83+
meta=None,
84+
)
85+
with (Path(dir_path) / "ffhq-dataset-meta.json").open(encoding="utf-8") as fp:
86+
self.landmarks: Dict[int, List[List[float]]] = {
87+
int(k): v["image"]["face_landmarks"] for k, v in json.load(fp)
88+
}
89+
7690
images_dir_path = self._get_absolute_local_path(dir_path)
7791
self.image_paths = self._get_all_paths_based_on_suffix(images_dir_path, self.image_suffix)
78-
f = open(Path(dir_path) / "ffhq-dataset-meta.json")
79-
self.landmarks_data = json.load(f)
80-
f.close()
81-
82-
# TODO: No good
83-
self.name = name
84-
self.batch_size = batch_size
85-
self.shuffle = shuffle
86-
87-
self.rng = np.random.default_rng(rng_seed)
88-
89-
self.idx_sampler = list(range(len(self.image_paths)))
90-
if shuffle:
91-
self.rng.shuffle(self.idx_sampler)
92-
93-
if collate_fn is not None:
94-
self._collate_fn = collate_fn
9592

9693
def get_marks(self, idx: int) -> Optional[np.ndarray]:
97-
return np.array(self.landmarks_data[str(idx)]["image"]["face_landmarks"])
94+
return np.array(self.landmarks[idx])
9895

9996
def get_meta(self, idx: int) -> Optional[Dict]:
100-
f = open(f"ffhq/{idx:05d}.json")
101-
meta = json.load(f)
102-
f.close()
97+
with Path(f"ffhq/{idx:05d}.json").open(encoding="utf-8") as fp:
98+
meta = json.load(fp)
10399
return meta[0]
104100

105101
@classmethod
@@ -114,5 +110,6 @@ def load_image_from_file(cls, image_file: Path) -> np.ndarray:
114110
"""
115111
return cv2.imread(str(image_file))
116112

113+
@classmethod
117114
def load_marks_from_file(cls, mark_file: Path) -> np.ndarray:
118-
pass
115+
raise NotImplementedError("Should not be called for FFHQ")

‎pyproject.toml‎

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ build-backend = "setuptools.build_meta"
66
name = "loreal-poc"
77
version = "2023.11.27"
88
description = "Assessing the quality of L'Oreal's facial landmark models"
9-
authors = [
10-
{name = "Rabah Abdul Khalek", email = "rabah@giskard.ai"},
11-
]
9+
authors = [{ name = "Rabah Abdul Khalek", email = "rabah@giskard.ai" }]
1210
dependencies = [
1311
"pillow>=10.1.0", # just for drawing
1412
"opencv-python",
@@ -23,20 +21,20 @@ notebook = "jupyter notebook --ip 0.0.0.0 --port 8888 --no-browser --notebook-di
2321
format.cmd = "bash -c 'ruff ./loreal_poc ./tests --fix && black ./loreal_poc ./examples ./tests && isort ./loreal_poc ./tests'"
2422
check-format.cmd = "bash -c 'ruff ./loreal_poc ./tests && black --check ./loreal_poc ./examples ./tests && isort --check ./loreal_poc ./tests'"
2523
test.cmd = "pytest tests/ -c pyproject.toml --disable-warnings -vvv --durations=0"
26-
check-notebook="bash -c 'cd ./examples && pdm run jupyter nbconvert --to script -y *.ipynb && find . -type f | grep -e \".py$\" | xargs -I {} echo \"pdm run python {} && echo \"Notebook {} OK\" || exit 1\" | sh'"
24+
check-notebook = "bash -c 'cd ./examples && pdm run jupyter nbconvert --to script -y *.ipynb && find . -type f | grep -e \".py$\" | xargs -I {} echo \"pdm run python {} && echo \"Notebook {} OK\" || exit 1\" | sh'"
2725

2826
[tool.pdm.dev-dependencies]
2927
dev = [
3028
"face-alignment",
31-
"opencv-contrib-python", # needed for lbfmodel
29+
"opencv-contrib-python", # needed for lbfmodel
3230
"notebook",
3331
"matplotlib",
3432
"black[jupyter]>=23.7.0",
3533
"pytest>=7.4.0",
3634
"pip>=23.2.1",
3735
"pre-commit>=2.19.0",
3836
"ruff",
39-
"isort"
37+
"isort",
4038
]
4139

4240
[tool.ruff]
@@ -65,5 +63,6 @@ exclude = '''
6563
| dist
6664
| env
6765
| venv
66+
| .history
6867
)/
69-
'''
68+
'''

‎tests/dataloaders/test_base.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ def __len__(self) -> int:
4242
def get_image(self, idx: int) -> np.ndarray:
4343
return self.dataset[idx]
4444

45-
@property
46-
def marks_none(self):
45+
@classmethod
46+
def marks_none(cls):
4747
return np.full((68, 2), np.nan)
4848

49-
@property
50-
def meta_none(self):
49+
@classmethod
50+
def meta_none(cls):
5151
return {"key1": -1, "key2": -1}
5252

5353
def get_marks(self, idx: int) -> np.ndarray | None:

0 commit comments

Comments
 (0)